You need to sign in or sign up before continuing.
Commit 1d3963d2 authored by Shangyan Zhou's avatar Shangyan Zhou
Browse files

Fix bar.sync

parent ef70b83e
...@@ -439,7 +439,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -439,7 +439,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
__shared__ int rdma_send_channel_lock[kNumRDMARanks]; __shared__ int rdma_send_channel_lock[kNumRDMARanks];
__shared__ int rdma_send_channel_tail[kNumRDMARanks]; __shared__ int rdma_send_channel_tail[kNumRDMARanks];
__shared__ uint32_t rdma_send_channel_window[kNumRDMARanks]; __shared__ uint32_t rdma_send_channel_window[kNumRDMARanks];
auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; auto sync_rdma_sender_smem = []() { asm volatile("barrier.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };
// TMA stuffs // TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[]; extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
...@@ -457,7 +457,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -457,7 +457,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Forward warp synchronization // Forward warp synchronization
__shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks]; __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
__shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS]; __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
auto sync_forwarder_smem = []() { asm volatile("bar.sync 1, %0;" :: "r"((NUM_MAX_NVL_PEERS + 1) * 32)); }; auto sync_forwarder_smem = []() { asm volatile("barrier.sync 1, %0;" :: "r"((NUM_MAX_NVL_PEERS + 1) * 32)); };
if (warp_role == WarpRole::kRDMASender) { if (warp_role == WarpRole::kRDMASender) {
// Get tasks // Get tasks
...@@ -1567,8 +1567,8 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1567,8 +1567,8 @@ combine(int4* combined_x, float* combined_topk_weights,
__shared__ volatile bool forwarder_retired[kNumForwarders]; __shared__ volatile bool forwarder_retired[kNumForwarders];
__shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks]; __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];
__shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers]; __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];
auto sync_forwarder_smem = [=]() { asm volatile("bar.sync 0, %0;" :: "r"((kNumForwarders + 1) * 32)); }; auto sync_forwarder_smem = [=]() { asm volatile("barrier.sync 0, %0;" :: "r"((kNumForwarders + 1) * 32)); };
auto sync_rdma_receiver_smem = [=]() { asm volatile("bar.sync 1, %0;" :: "r"((kNumRDMAReceivers + 1) * 32)); }; auto sync_rdma_receiver_smem = [=]() { asm volatile("barrier.sync 1, %0;" :: "r"((kNumRDMAReceivers + 1) * 32)); };
if (warp_role == WarpRole::kNVLAndRDMAForwarder) { if (warp_role == WarpRole::kNVLAndRDMAForwarder) {
// Receive from NVL ranks and forward to RDMA ranks // Receive from NVL ranks and forward to RDMA ranks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment