"vscode:/vscode.git/clone" did not exist on "808a9ed98da97fe226d4c08093044809b7b64306"
Unverified Commit 55d8073d authored by Yifan Qiao's avatar Yifan Qiao Committed by GitHub
Browse files

[Bugfix] ep_scatter kernel store-load race condition (#34991)


Signed-off-by: default avatarYifan Qiao <yifanqiao@berkeley.edu>
parent cd32d6f5
...@@ -76,9 +76,13 @@ def _fwd_kernel_ep_scatter_1( ...@@ -76,9 +76,13 @@ def _fwd_kernel_ep_scatter_1(
) )
tokens_per_expert = round_up_128(tokens_per_expert) tokens_per_expert = round_up_128(tokens_per_expert)
cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
cur_expert_start = tl.load(expert_start_loc + cur_expert) # Extract this block's offset from the register vector (warp shuffle,
# no global memory round-trip) then write it once to expert_start_loc.
cur_expert_start = tl.sum(
tl.where(offset_cumsum == cur_expert, cumsum, tl.zeros_like(cumsum))
)
tl.store(expert_start_loc + cur_expert, cur_expert_start)
cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
m_indices_start_ptr = m_indices + cur_expert_start m_indices_start_ptr = m_indices + cur_expert_start
...@@ -87,7 +91,7 @@ def _fwd_kernel_ep_scatter_1( ...@@ -87,7 +91,7 @@ def _fwd_kernel_ep_scatter_1(
# any rows in the per-expert aligned region that do not correspond to # any rows in the per-expert aligned region that do not correspond to
# real tokens are left untouched here and should remain initialized to # real tokens are left untouched here and should remain initialized to
# -1 so DeepGEMM can skip them # -1 so DeepGEMM can skip them
for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4): for start_m in tl.range(0, cur_expert_token_num, BLOCK_E):
offs = start_m + off_expert offs = start_m + off_expert
mask = offs < cur_expert_token_num mask = offs < cur_expert_token_num
tl.store( tl.store(
...@@ -186,6 +190,7 @@ def ep_scatter( ...@@ -186,6 +190,7 @@ def ep_scatter(
grid = num_experts grid = num_experts
assert m_indices.shape[0] % BLOCK_E == 0 assert m_indices.shape[0] % BLOCK_E == 0
assert expert_start_loc.shape[0] == num_experts
_fwd_kernel_ep_scatter_1[(grid,)]( _fwd_kernel_ep_scatter_1[(grid,)](
num_recv_tokens_per_expert, num_recv_tokens_per_expert,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment