Unverified Commit 7de7464e authored by Shangyan Zhou's avatar Shangyan Zhou Committed by GitHub
Browse files

Remove memory fence in NVLink barrier. (#253)



* Remove memory fence in NVLink barrier.

* Move `__syncthread` and fence into barrier.

* Fix bugs

---------
Co-authored-by: default avatarChenggang Zhao <chenggangz@deepseek.com>
parent 4e72eb39
......@@ -99,7 +99,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
barrier_block<NUM_MAX_NVL_PEERS, true>(barrier_signal_ptrs, nvl_rank);
// Send numbers of tokens per rank/expert to RDMA ranks
auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr);
......@@ -199,8 +199,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
for (int i = 0; i < num_nvl_experts; ++ i)
nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];
}
memory_fence();
__syncthreads();
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
// Reduce the number of tokens per rank/expert
......@@ -227,7 +225,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
}
// Finally barrier
__syncthreads();
if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
......@@ -1040,15 +1037,13 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
} else if (sm_id == 1) {
// Barrier for NVL
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
barrier_block<NUM_MAX_NVL_PEERS, true>(barrier_signal_ptrs, nvl_rank);
// Clean
auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]);
#pragma unroll
for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
memory_fence();
__syncthreads();
// Barrier again
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
......
......@@ -21,7 +21,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
if (sm_id == 0) {
// Barrier first
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);
int *per_rank_buffer, *per_expert_buffer;
if (thread_id < kNumRanks) {
......@@ -41,7 +41,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
for (int i = 0; i < num_experts_per_rank; ++ i)
per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i];
}
__syncthreads();
// Wait for all ranks to be finished
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
......@@ -80,8 +79,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
local_per_expert_buffer[i] = 0;
// Barrier
memory_fence();
__syncthreads();
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
} else {
int dst_rank = sm_id - 1;
......@@ -137,7 +134,7 @@ __global__ void
cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank) {
// A simplified version for cached handles
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);
// Copy and clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
......@@ -148,8 +145,6 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
#pragma unroll
for (int i = thread_id; i < num_memset_int; i += num_threads)
ptr[kNumRanks * kNumRanks + i] = 0;
memory_fence();
__syncthreads();
// Barrier after cleaning
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
......@@ -520,7 +515,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
const auto sm_id = static_cast<int>(blockIdx.x);
if (sm_id == 0) {
// Barrier before cleaning
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);
// Clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
......@@ -528,8 +523,6 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
#pragma unroll
for (int i = thread_id; i < num_memset_int; i += num_threads)
ptr[i] = 0;
memory_fence();
__syncthreads();
// Barrier after cleaning
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
......
......@@ -438,15 +438,20 @@ __forceinline__ __device__ out_dtype_t extract_required_scale_format(float value
}
}
template <int kNumRanks>
template <int kNumRanks, bool kSyncOnly = false>
__forceinline__ __device__ void
barrier_block(int** barrier_signal_ptrs, int rank) {
auto thread_id = static_cast<int>(threadIdx.x);
// For non-sync-only cases, the memory operations by other threads in the block must be visible to the `sys` scope
if constexpr (not kSyncOnly) {
memory_fence();
__syncthreads();
}
// Add self-ranks, sub other ranks
if (thread_id < kNumRanks) {
atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
memory_fence();
atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);
}
EP_DEVICE_ASSERT(kNumRanks <= blockDim.x);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment