Commit ffc39ba0 authored by Chenggang Zhao's avatar Chenggang Zhao
Browse files

Stronger acquire scope for low-latency kernels

parent 7d52ad72
...@@ -260,7 +260,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, ...@@ -260,7 +260,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int num_recv_tokens, recv_token_begin_idx; int num_recv_tokens, recv_token_begin_idx;
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group"); EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
if (sub_warp_id == 1 and lane_id == 0) { if (sub_warp_id == 1 and lane_id == 0) {
while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0); while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
num_recv_tokens = -num_recv_tokens - 1; num_recv_tokens = -num_recv_tokens - 1;
recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens); recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
shared_num_recv_tokens[warp_group_id] = num_recv_tokens; shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
...@@ -450,7 +450,7 @@ combine(void* combined_x, ...@@ -450,7 +450,7 @@ combine(void* combined_x,
if (responsible_expert_idx < num_experts) { if (responsible_expert_idx < num_experts) {
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group"); EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
if (sub_warp_id == 0 and lane_id == 0) if (sub_warp_id == 0 and lane_id == 0)
while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0); while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
} }
cg::this_grid().sync(); cg::this_grid().sync();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment