"platforms/reference/vscode:/vscode.git/clone" did not exist on "9e56b300f22a7e63335ed930f1024f22a05b8c4d"
Unverified Commit e6012370 authored by Shangyan Zhou's avatar Shangyan Zhou Committed by GitHub
Browse files

Fix for data error and kernel hung because of inflight rdma channel head update (#310)

Fix for data error and kernel hung because of inflight rdma channel head update
parents 0eee87b8 b65b22ed
...@@ -102,6 +102,17 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in ...@@ -102,6 +102,17 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
// Global barrier: the first warp does intra-node sync, the second warp does internode sync // Global barrier: the first warp does intra-node sync, the second warp does internode sync
EP_DEVICE_ASSERT(num_warps > 1); EP_DEVICE_ASSERT(num_warps > 1);
EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads); EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
// waiting for all previous inflight wrs to complete,
// in case of rewriting cleared rdma_buffer
auto qps_per_rdma_rank = ibgda_get_state()->num_rc_per_pe * ibgda_get_state()->num_devices_initialized;
for (int i = thread_id; i < qps_per_rdma_rank * (kNumRDMARanks - 1); i += num_threads) {
auto dst_rdma_rank = (i / qps_per_rdma_rank + rdma_rank + 1) % kNumRDMARanks;
auto qp_id = i % qps_per_rdma_rank;
nvshmemi_ibgda_quiet(translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), qp_id);
}
__syncthreads();
if (thread_id == 32) if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team); nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS, true>(barrier_signal_ptrs, nvl_rank); barrier_block<NUM_MAX_NVL_PEERS, true>(barrier_signal_ptrs, nvl_rank);
...@@ -1044,9 +1055,18 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in ...@@ -1044,9 +1055,18 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
auto nvl_rank = rank % NUM_MAX_NVL_PEERS; auto nvl_rank = rank % NUM_MAX_NVL_PEERS;
auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
auto rdma_rank = rank / NUM_MAX_NVL_PEERS;
// Using two SMs, which clean the RDMA/NVL buffer respectively // Using two SMs, which clean the RDMA/NVL buffer respectively
if (sm_id == 0) { if (sm_id == 0) {
auto qps_per_rdma_rank = ibgda_get_state()->num_rc_per_pe * ibgda_get_state()->num_devices_initialized;
for (int i = thread_id; i < qps_per_rdma_rank * (num_rdma_ranks - 1); i += num_threads) {
auto dst_rdma_rank = (i / qps_per_rdma_rank + rdma_rank + 1) % num_rdma_ranks;
auto qp_id = i % qps_per_rdma_rank;
nvshmemi_ibgda_quiet(translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), qp_id);
}
__syncthreads();
// Barrier for RDMA // Barrier for RDMA
if (thread_id == 32) if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team); nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment