"vscode:/vscode.git/clone" did not exist on "7d0c2729399c3ce019a30fc175b973e892fd5fc3"
Unverified commit e7234756, authored by shenggan, committed by GitHub

[hotfix] fix row_idx overflow in triton softmax (#80)

parent ec9352d1
@@ -50,7 +50,7 @@ def _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_
 def softmax_mask_bias_kernel(output_ptr, input_ptr, mask_ptr, bias_ptr, input_row_stride,
                              output_row_stride, n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
                              use_mask: tl.constexpr, use_bias: tl.constexpr):
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     input_row_ptr = input_ptr + row_idx * input_row_stride
@@ -77,7 +77,7 @@ def softmax_mask_bias_kernel(output_ptr, input_ptr, mask_ptr, bias_ptr, input_ro
 def softmax_mask_bias_kernel_two_rows(output_ptr, input_ptr, mask_ptr, bias_ptr, input_row_stride,
                                       output_row_stride, n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
                                       use_mask: tl.constexpr, use_bias: tl.constexpr):
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     input_row_ptr = input_ptr + 2 * row_idx * input_row_stride
@@ -119,7 +119,7 @@ def softmax_mask_grad_kernel(d_output_ptr, output_ptr, d_input_ptr, mask_ptr, d_
                              BLOCK_SIZE: tl.constexpr, is_bf16: tl.constexpr,
                              use_mask: tl.constexpr):
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     output_row_ptr = output_ptr + row_idx * output_row_stride
@@ -145,7 +145,7 @@ def softmax_mask_grad_kernel_two_rows(d_output_ptr, output_ptr, d_input_ptr, mas
                                       n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
                                       is_bf16: tl.constexpr, use_mask: tl.constexpr):
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     output_row_ptr = output_ptr + 2 * row_idx * output_row_stride
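Why the cast matters: tl.program_id returns a 32-bit index, so when a kernel computes a row pointer as row_idx * row_stride, the product can exceed 2**31 - 1 for large inputs and wrap around, producing a corrupted pointer. Widening row_idx to tl.int64 before the multiplication keeps the offset arithmetic in 64 bits; the same one-line cast is applied to all four kernels in the hunks above. Below is a minimal, self-contained sketch of the same pattern on a plain row-wise softmax. It is not the FastFold kernel itself, and the names (softmax_rows_kernel, softmax_rows) are illustrative only.

# Minimal sketch of the int64-index pattern on a simple row-wise softmax.
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_rows_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride,
                        n_cols, BLOCK_SIZE: tl.constexpr):
    # Widen the 32-bit program id so row_idx * row_stride below cannot
    # overflow int32 when the tensor has many rows or a large stride.
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # 64-bit offset arithmetic for the row base pointer.
    row_start = input_ptr + row_idx * input_row_stride
    row = tl.load(row_start + col_offsets, mask=mask, other=-float('inf'))

    # Numerically stable softmax over one row.
    row = row - tl.max(row, axis=0)
    num = tl.exp(row)
    out = num / tl.sum(num, axis=0)

    tl.store(output_ptr + row_idx * output_row_stride + col_offsets, out, mask=mask)


def softmax_rows(x: torch.Tensor) -> torch.Tensor:
    # One program per row; BLOCK_SIZE covers the whole row.
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    softmax_rows_kernel[(n_rows,)](y, x, x.stride(0), y.stride(0), n_cols,
                                   BLOCK_SIZE=BLOCK_SIZE)
    return y

Because the strides are passed as ordinary integers, multiplying them by an int64 row_idx promotes the whole offset expression to 64 bits, which is exactly what the one-line change in each kernel above achieves.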