Unverified Commit e3b8a722 authored by xutizhou's avatar xutizhou Committed by GitHub
Browse files

[fix] illegal memory in _fwd_kernel_ep_scatter_2 and _fwd_kernel_ep_gather (#6348)

parent 3cf1473a
......@@ -791,19 +791,23 @@ def _fwd_kernel_ep_scatter_2(
offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
mask_s = offset_in_s < SCALE_HIDDEN_SIZE
for token_id in range(start_token_id, total_token_num, grid_num):
for token_id_int32 in range(start_token_id, total_token_num, grid_num):
token_id = token_id_int32.to(tl.int64)
to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
to_copy_s = tl.load(
recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
)
for topk_index in tl.range(0, topk_num, 1, num_stages=4):
for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
topk_index = topk_idx_int32.to(tl.int64)
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
if expert_id >= 0:
dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
dest_token_index = dest_token_index_int32.to(tl.int64)
tl.store(
output_index + token_id * output_index_stride0 + topk_index,
dest_token_index,
dest_token_index_int32,
)
output_tensor_ptr = (
output_tensor + dest_token_index * output_tensor_stride0
......@@ -902,21 +906,31 @@ def _fwd_kernel_ep_gather(
topk_num: tl.constexpr,
BLOCK_D: tl.constexpr,
):
cur_block = tl.program_id(0)
start_cur_token = tl.program_id(1)
cur_block_int32 = tl.program_id(0)
cur_block = cur_block_int32.to(tl.int64)
start_cur_token_int32 = tl.program_id(1)
grid_num = tl.num_programs(1)
for cur_token in range(start_cur_token, total_token_num, grid_num):
for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
cur_token = cur_token_int32.to(tl.int64)
off_d = tl.arange(0, BLOCK_D)
accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
for topk_index in range(0, topk_num):
for topk_index_int32 in range(0, topk_num):
topk_index = topk_index_int32.to(tl.int64)
expert_id = tl.load(
recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
)
if expert_id >= 0:
source_token_index = tl.load(
source_token_index_int32 = tl.load(
input_index + cur_token * input_index_stride0 + topk_index
)
source_token_index = source_token_index_int32.to(tl.int64)
acc_weight = tl.load(
recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment