"git@developer.sourcefind.cn:change/sglang.git" did not exist on "b688fd858d89258ba0018e5903c2907badf49afa"
Unverified commit e3b8a722, authored by xutizhou and committed by GitHub
Browse files

[fix] illegal memory access in _fwd_kernel_ep_scatter_2 and _fwd_kernel_ep_gather (#6348)

parent 3cf1473a
...@@ -791,19 +791,23 @@ def _fwd_kernel_ep_scatter_2( ...@@ -791,19 +791,23 @@ def _fwd_kernel_ep_scatter_2(
offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD) offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
mask_s = offset_in_s < SCALE_HIDDEN_SIZE mask_s = offset_in_s < SCALE_HIDDEN_SIZE
for token_id in range(start_token_id, total_token_num, grid_num): for token_id_int32 in range(start_token_id, total_token_num, grid_num):
token_id = token_id_int32.to(tl.int64)
to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask) to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
to_copy_s = tl.load( to_copy_s = tl.load(
recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
) )
for topk_index in tl.range(0, topk_num, 1, num_stages=4): for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
topk_index = topk_idx_int32.to(tl.int64)
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index) expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
if expert_id >= 0: if expert_id >= 0:
dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1) dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
dest_token_index = dest_token_index_int32.to(tl.int64)
tl.store( tl.store(
output_index + token_id * output_index_stride0 + topk_index, output_index + token_id * output_index_stride0 + topk_index,
dest_token_index, dest_token_index_int32,
) )
output_tensor_ptr = ( output_tensor_ptr = (
output_tensor + dest_token_index * output_tensor_stride0 output_tensor + dest_token_index * output_tensor_stride0
...@@ -902,21 +906,31 @@ def _fwd_kernel_ep_gather( ...@@ -902,21 +906,31 @@ def _fwd_kernel_ep_gather(
topk_num: tl.constexpr, topk_num: tl.constexpr,
BLOCK_D: tl.constexpr, BLOCK_D: tl.constexpr,
): ):
cur_block = tl.program_id(0) cur_block_int32 = tl.program_id(0)
start_cur_token = tl.program_id(1) cur_block = cur_block_int32.to(tl.int64)
start_cur_token_int32 = tl.program_id(1)
grid_num = tl.num_programs(1) grid_num = tl.num_programs(1)
for cur_token in range(start_cur_token, total_token_num, grid_num): for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
cur_token = cur_token_int32.to(tl.int64)
off_d = tl.arange(0, BLOCK_D) off_d = tl.arange(0, BLOCK_D)
accumulator = tl.zeros([BLOCK_D], dtype=tl.float32) accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
for topk_index in range(0, topk_num):
for topk_index_int32 in range(0, topk_num):
topk_index = topk_index_int32.to(tl.int64)
expert_id = tl.load( expert_id = tl.load(
recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
) )
if expert_id >= 0: if expert_id >= 0:
source_token_index = tl.load( source_token_index_int32 = tl.load(
input_index + cur_token * input_index_stride0 + topk_index input_index + cur_token * input_index_stride0 + topk_index
) )
source_token_index = source_token_index_int32.to(tl.int64)
acc_weight = tl.load( acc_weight = tl.load(
recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment