Unverified Commit e30c36a3 authored by hx's avatar hx Committed by GitHub
Browse files

[PyTorch] fix int32 overflow in permute kernels (#2196)



* fix overflow of int32 in permute kernels
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent be7f43f1
......@@ -324,7 +324,8 @@ def _permute_kernel(
pid_h = tl.program_id(1)
cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = cur_off < hidden_size
input_off = pid_t * stride_input_token + cur_off * stride_input_hidden
src_row = pid_t.to(tl.int64)
input_off = src_row * stride_input_token + cur_off * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
if PERMUTE_SCALE:
mask_scale = cur_off < scale_hidden_dim
......@@ -338,7 +339,7 @@ def _permute_kernel(
for idx in tl.range(n_routed):
dst_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
)
).to(tl.int64)
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
if PERMUTE_SCALE:
permuted_scale_off = (
......@@ -519,7 +520,7 @@ def _unpermute_kernel(
for idx in tl.range(n_routed):
src_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
)
).to(tl.int64)
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
......@@ -550,7 +551,8 @@ def _unpermute_kernel(
prob = tl.load(permuted_probs_ptr + permuted_prob_off)
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
accumulator = accumulator.to(data_type)
output_off = pid_t * stride_output_token + current_offset * stride_output_hidden
dst_row = pid_t.to(tl.int64)
output_off = dst_row * stride_output_token + current_offset * stride_output_hidden
tl.store(output_ptr + output_off, accumulator, mask=mask)
......@@ -681,7 +683,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
for idx in tl.range(n_routed):
dst_row = tl.load(
row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
)
).to(tl.int64)
expert_idx = tl.load(
row_id_map_ptr
+ pid * stride_row_id_map_token
......@@ -692,8 +694,10 @@ def _unpermute_bwd_with_merging_probs_kernel(
while current_start < hidden_size:
current_offset = current_start + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size
src_row = pid.to(tl.int64)
input_off = (
pid * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden
src_row * stride_fwd_output_grad_token
+ current_offset * stride_fwd_output_grad_hidden
)
inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
......@@ -902,11 +906,11 @@ def _sort_chunks_by_map_kernel(
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
if FORWARD:
src_row = pid_t
dst_row = tl.load(row_id_map_ptr + pid_t)
src_row = pid_t.to(tl.int64)
dst_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
else:
src_row = tl.load(row_id_map_ptr + pid_t)
dst_row = pid_t
src_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
dst_row = pid_t.to(tl.int64)
current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size
input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment