Commit 079c5a4f authored by Shangyan Zhou
Browse files

Fix

parent eb155da4
...@@ -1048,23 +1048,17 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in ...@@ -1048,23 +1048,17 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
// Using two SMs, which clean the RDMA/NVL buffer respectively // Using two SMs, which clean the RDMA/NVL buffer respectively
if (sm_id == 0) { if (sm_id == 0) {
// Barrier for RDMA // Barrier for RDMA
if (thread_id == 0) if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team); nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
__syncthreads();
// Barrier for NVL
barrier_block<NUM_MAX_NVL_PEERS, true>(barrier_signal_ptrs, nvl_rank);
// Clean RDMA buffer // Clean RDMA buffer
auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr); auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr);
#pragma unroll #pragma unroll
for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
__syncthreads();
// Barrier again
if (thread_id == 0)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
} else if (sm_id == 1) {
// Barrier for NVL
barrier_block<NUM_MAX_NVL_PEERS, true>(barrier_signal_ptrs, nvl_rank);
// Clean NVL buffer // Clean NVL buffer
auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]); auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]);
...@@ -1074,8 +1068,8 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in ...@@ -1074,8 +1068,8 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
__syncthreads(); __syncthreads();
// Barrier again // Barrier again
if (warp_id == 1) if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx_warp<kLowLatencyMode>(rdma_team, rank, lane_id); nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank); barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
} else if (sm_id == 1) { } else if (sm_id == 1) {
if (is_cached_dispatch) if (is_cached_dispatch)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment