Unverified commit 7de7464e, authored by Shangyan Zhou, committed by GitHub
Browse files

Remove memory fence in NVLink barrier. (#253)



* Remove memory fence in NVLink barrier.

* Move `__syncthreads` and fence into barrier.

* Fix bugs

---------
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
parent 4e72eb39
...@@ -99,7 +99,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in ...@@ -99,7 +99,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads); EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
if (thread_id == 32) if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team); nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank); barrier_block<NUM_MAX_NVL_PEERS, true>(barrier_signal_ptrs, nvl_rank);
// Send numbers of tokens per rank/expert to RDMA ranks // Send numbers of tokens per rank/expert to RDMA ranks
auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr); auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr);
...@@ -199,8 +199,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in ...@@ -199,8 +199,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
for (int i = 0; i < num_nvl_experts; ++ i) for (int i = 0; i < num_nvl_experts; ++ i)
nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i]; nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];
} }
memory_fence();
__syncthreads();
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank); barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
// Reduce the number of tokens per rank/expert // Reduce the number of tokens per rank/expert
...@@ -227,7 +225,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in ...@@ -227,7 +225,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
} }
// Finally barrier // Finally barrier
__syncthreads();
if (thread_id == 32) if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team); nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank); barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
...@@ -1040,15 +1037,13 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in ...@@ -1040,15 +1037,13 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team); nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
} else if (sm_id == 1) { } else if (sm_id == 1) {
// Barrier for NVL // Barrier for NVL
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank); barrier_block<NUM_MAX_NVL_PEERS, true>(barrier_signal_ptrs, nvl_rank);
// Clean // Clean
auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]); auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]);
#pragma unroll #pragma unroll
for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
memory_fence();
__syncthreads();
// Barrier again // Barrier again
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank); barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
......
...@@ -21,7 +21,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, ...@@ -21,7 +21,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
if (sm_id == 0) { if (sm_id == 0) {
// Barrier first // Barrier first
barrier_block<kNumRanks>(barrier_signal_ptrs, rank); barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);
int *per_rank_buffer, *per_expert_buffer; int *per_rank_buffer, *per_expert_buffer;
if (thread_id < kNumRanks) { if (thread_id < kNumRanks) {
...@@ -41,7 +41,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, ...@@ -41,7 +41,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
for (int i = 0; i < num_experts_per_rank; ++ i) for (int i = 0; i < num_experts_per_rank; ++ i)
per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i]; per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i];
} }
__syncthreads();
// Wait for all ranks to be finished // Wait for all ranks to be finished
barrier_block<kNumRanks>(barrier_signal_ptrs, rank); barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
...@@ -80,8 +79,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, ...@@ -80,8 +79,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
local_per_expert_buffer[i] = 0; local_per_expert_buffer[i] = 0;
// Barrier // Barrier
memory_fence();
__syncthreads();
barrier_block<kNumRanks>(barrier_signal_ptrs, rank); barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
} else { } else {
int dst_rank = sm_id - 1; int dst_rank = sm_id - 1;
...@@ -137,7 +134,7 @@ __global__ void ...@@ -137,7 +134,7 @@ __global__ void
cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank) { void** buffer_ptrs, int** barrier_signal_ptrs, int rank) {
// A simplified version for cached handles // A simplified version for cached handles
barrier_block<kNumRanks>(barrier_signal_ptrs, rank); barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);
// Copy and clean // Copy and clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x); auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
...@@ -148,8 +145,6 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, ...@@ -148,8 +145,6 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
#pragma unroll #pragma unroll
for (int i = thread_id; i < num_memset_int; i += num_threads) for (int i = thread_id; i < num_memset_int; i += num_threads)
ptr[kNumRanks * kNumRanks + i] = 0; ptr[kNumRanks * kNumRanks + i] = 0;
memory_fence();
__syncthreads();
// Barrier after cleaning // Barrier after cleaning
barrier_block<kNumRanks>(barrier_signal_ptrs, rank); barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
...@@ -520,7 +515,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int ...@@ -520,7 +515,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
if (sm_id == 0) { if (sm_id == 0) {
// Barrier before cleaning // Barrier before cleaning
barrier_block<kNumRanks>(barrier_signal_ptrs, rank); barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);
// Clean // Clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x); auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
...@@ -528,8 +523,6 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int ...@@ -528,8 +523,6 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
#pragma unroll #pragma unroll
for (int i = thread_id; i < num_memset_int; i += num_threads) for (int i = thread_id; i < num_memset_int; i += num_threads)
ptr[i] = 0; ptr[i] = 0;
memory_fence();
__syncthreads();
// Barrier after cleaning // Barrier after cleaning
barrier_block<kNumRanks>(barrier_signal_ptrs, rank); barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
......
...@@ -438,15 +438,20 @@ __forceinline__ __device__ out_dtype_t extract_required_scale_format(float value ...@@ -438,15 +438,20 @@ __forceinline__ __device__ out_dtype_t extract_required_scale_format(float value
} }
} }
template <int kNumRanks> template <int kNumRanks, bool kSyncOnly = false>
__forceinline__ __device__ void __forceinline__ __device__ void
barrier_block(int** barrier_signal_ptrs, int rank) { barrier_block(int** barrier_signal_ptrs, int rank) {
auto thread_id = static_cast<int>(threadIdx.x); auto thread_id = static_cast<int>(threadIdx.x);
// For non-sync-only cases, the memory operations by other threads in the block must be visible to the `sys` scope
if constexpr (not kSyncOnly) {
memory_fence();
__syncthreads();
}
// Add self-ranks, sub other ranks // Add self-ranks, sub other ranks
if (thread_id < kNumRanks) { if (thread_id < kNumRanks) {
atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG); atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
memory_fence();
atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG); atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);
} }
EP_DEVICE_ASSERT(kNumRanks <= blockDim.x); EP_DEVICE_ASSERT(kNumRanks <= blockDim.x);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment