Unverified Commit e7234756 authored by shenggan's avatar shenggan Committed by GitHub
Browse files

[hotfix] fix row_idx overflow in triton softmax (#80)

parent ec9352d1
......@@ -50,7 +50,7 @@ def _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_
def softmax_mask_bias_kernel(output_ptr, input_ptr, mask_ptr, bias_ptr, input_row_stride,
output_row_stride, n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
use_mask: tl.constexpr, use_bias: tl.constexpr):
row_idx = tl.program_id(0)
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
input_row_ptr = input_ptr + row_idx * input_row_stride
......@@ -77,7 +77,7 @@ def softmax_mask_bias_kernel(output_ptr, input_ptr, mask_ptr, bias_ptr, input_ro
def softmax_mask_bias_kernel_two_rows(output_ptr, input_ptr, mask_ptr, bias_ptr, input_row_stride,
output_row_stride, n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
use_mask: tl.constexpr, use_bias: tl.constexpr):
row_idx = tl.program_id(0)
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
input_row_ptr = input_ptr + 2 * row_idx * input_row_stride
......@@ -119,7 +119,7 @@ def softmax_mask_grad_kernel(d_output_ptr, output_ptr, d_input_ptr, mask_ptr, d_
BLOCK_SIZE: tl.constexpr, is_bf16: tl.constexpr,
use_mask: tl.constexpr):
row_idx = tl.program_id(0)
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
output_row_ptr = output_ptr + row_idx * output_row_stride
......@@ -145,7 +145,7 @@ def softmax_mask_grad_kernel_two_rows(d_output_ptr, output_ptr, d_input_ptr, mas
n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
is_bf16: tl.constexpr, use_mask: tl.constexpr):
row_idx = tl.program_id(0)
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
output_row_ptr = output_ptr + 2 * row_idx * output_row_stride
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment