Unverified Commit e6012370 authored by Shangyan Zhou's avatar Shangyan Zhou Committed by GitHub
Browse files

Fix for data error and kernel hung because of inflight rdma channel head update (#310)

Fix for data error and kernel hung because of inflight rdma channel head update
parents 0eee87b8 b65b22ed
...@@ -102,6 +102,17 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in ...@@ -102,6 +102,17 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
// Global barrier: the first warp does intra-node sync, the second warp does internode sync // Global barrier: the first warp does intra-node sync, the second warp does internode sync
EP_DEVICE_ASSERT(num_warps > 1); EP_DEVICE_ASSERT(num_warps > 1);
EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads); EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
// waiting for all previous inflight wrs to complete,
// in case of rewriting cleared rdma_buffer
auto qps_per_rdma_rank = ibgda_get_state()->num_rc_per_pe * ibgda_get_state()->num_devices_initialized;
for (int i = thread_id; i < qps_per_rdma_rank * (kNumRDMARanks - 1); i += num_threads) {
auto dst_rdma_rank = (i / qps_per_rdma_rank + rdma_rank + 1) % kNumRDMARanks;
auto qp_id = i % qps_per_rdma_rank;
nvshmemi_ibgda_quiet(translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), qp_id);
}
__syncthreads();
if (thread_id == 32) if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team); nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS, true>(barrier_signal_ptrs, nvl_rank); barrier_block<NUM_MAX_NVL_PEERS, true>(barrier_signal_ptrs, nvl_rank);
...@@ -1044,9 +1055,18 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in ...@@ -1044,9 +1055,18 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
auto nvl_rank = rank % NUM_MAX_NVL_PEERS; auto nvl_rank = rank % NUM_MAX_NVL_PEERS;
auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
auto rdma_rank = rank / NUM_MAX_NVL_PEERS;
// Using two SMs, which clean the RDMA/NVL buffer respectively // Using two SMs, which clean the RDMA/NVL buffer respectively
if (sm_id == 0) { if (sm_id == 0) {
auto qps_per_rdma_rank = ibgda_get_state()->num_rc_per_pe * ibgda_get_state()->num_devices_initialized;
for (int i = thread_id; i < qps_per_rdma_rank * (num_rdma_ranks - 1); i += num_threads) {
auto dst_rdma_rank = (i / qps_per_rdma_rank + rdma_rank + 1) % num_rdma_ranks;
auto qp_id = i % qps_per_rdma_rank;
nvshmemi_ibgda_quiet(translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), qp_id);
}
__syncthreads();
// Barrier for RDMA // Barrier for RDMA
if (thread_id == 32) if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team); nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment