Unverified Commit e30c36a3 authored by hx Committed by GitHub
Browse files

[PyTorch] fix int32 overflow in permute kernels (#2196)



* fix overflow of int32 in permute kernels
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent be7f43f1
...@@ -324,7 +324,8 @@ def _permute_kernel( ...@@ -324,7 +324,8 @@ def _permute_kernel(
pid_h = tl.program_id(1) pid_h = tl.program_id(1)
cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = cur_off < hidden_size mask = cur_off < hidden_size
input_off = pid_t * stride_input_token + cur_off * stride_input_hidden src_row = pid_t.to(tl.int64)
input_off = src_row * stride_input_token + cur_off * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask) inp = tl.load(input_ptr + input_off, mask=mask)
if PERMUTE_SCALE: if PERMUTE_SCALE:
mask_scale = cur_off < scale_hidden_dim mask_scale = cur_off < scale_hidden_dim
...@@ -338,7 +339,7 @@ def _permute_kernel( ...@@ -338,7 +339,7 @@ def _permute_kernel(
for idx in tl.range(n_routed): for idx in tl.range(n_routed):
dst_row = tl.load( dst_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
) ).to(tl.int64)
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
if PERMUTE_SCALE: if PERMUTE_SCALE:
permuted_scale_off = ( permuted_scale_off = (
...@@ -519,7 +520,7 @@ def _unpermute_kernel( ...@@ -519,7 +520,7 @@ def _unpermute_kernel(
for idx in tl.range(n_routed): for idx in tl.range(n_routed):
src_row = tl.load( src_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
) ).to(tl.int64)
input_off = src_row * stride_input_token + current_offset * stride_input_hidden input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask) inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type) inp = inp.to(compute_type)
...@@ -550,7 +551,8 @@ def _unpermute_kernel( ...@@ -550,7 +551,8 @@ def _unpermute_kernel(
prob = tl.load(permuted_probs_ptr + permuted_prob_off) prob = tl.load(permuted_probs_ptr + permuted_prob_off)
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
accumulator = accumulator.to(data_type) accumulator = accumulator.to(data_type)
output_off = pid_t * stride_output_token + current_offset * stride_output_hidden dst_row = pid_t.to(tl.int64)
output_off = dst_row * stride_output_token + current_offset * stride_output_hidden
tl.store(output_ptr + output_off, accumulator, mask=mask) tl.store(output_ptr + output_off, accumulator, mask=mask)
...@@ -681,7 +683,7 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -681,7 +683,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
for idx in tl.range(n_routed): for idx in tl.range(n_routed):
dst_row = tl.load( dst_row = tl.load(
row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
) ).to(tl.int64)
expert_idx = tl.load( expert_idx = tl.load(
row_id_map_ptr row_id_map_ptr
+ pid * stride_row_id_map_token + pid * stride_row_id_map_token
...@@ -692,8 +694,10 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -692,8 +694,10 @@ def _unpermute_bwd_with_merging_probs_kernel(
while current_start < hidden_size: while current_start < hidden_size:
current_offset = current_start + tl.arange(0, BLOCK_SIZE) current_offset = current_start + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size mask = current_offset < hidden_size
src_row = pid.to(tl.int64)
input_off = ( input_off = (
pid * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden src_row * stride_fwd_output_grad_token
+ current_offset * stride_fwd_output_grad_hidden
) )
inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
inp = inp.to(compute_type) inp = inp.to(compute_type)
...@@ -902,11 +906,11 @@ def _sort_chunks_by_map_kernel( ...@@ -902,11 +906,11 @@ def _sort_chunks_by_map_kernel(
pid_t = tl.program_id(0) pid_t = tl.program_id(0)
pid_h = tl.program_id(1) pid_h = tl.program_id(1)
if FORWARD: if FORWARD:
src_row = pid_t src_row = pid_t.to(tl.int64)
dst_row = tl.load(row_id_map_ptr + pid_t) dst_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
else: else:
src_row = tl.load(row_id_map_ptr + pid_t) src_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
dst_row = pid_t dst_row = pid_t.to(tl.int64)
current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size mask = current_offset < hidden_size
input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment