Commit 8a34a9bd authored by lishen's avatar lishen
Browse files

modify internode notify

parent 17d9c844
......@@ -116,13 +116,10 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
if (sm_id == 0) {
// Communication with others
// Global barrier: the first warp do intra-node sync, the second warp do internode sync
EP_DEVICE_ASSERT(num_warps > 1);
EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
if (thread_id == kWarpSize)
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
__syncthreads();
// Send numbers of tokens per rank/expert to RDMA ranks
auto rdma_buffer_ptr_int = reinterpret_cast<int *>(rdma_buffer_ptr);
......@@ -152,14 +149,25 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
// Issue send
// TODO: more light fence or barrier or signaling
// TODO: overlap EP barrier and NVL cleaning
if (thread_id < kNumRDMARanks) {
shmem_int_put_nbi(
for (int i = warp_id; i < kNumRDMARanks; i += num_warps) {
if (i != rdma_rank) {
shmemx_int_put_nbi_warp(
rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
rdma_recv_num_tokens_mixed.send_buffer(thread_id),
rdma_recv_num_tokens_mixed.send_buffer(i),
NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank));
translate_dst_rdma_rank<kLowLatencyMode>(i, nvl_rank));
} else {
UNROLLED_WARP_COPY(1,
lane_id,
NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
rdma_recv_num_tokens_mixed.send_buffer(i),
ld_volatile_global,
st_na_global);
}
}
__syncthreads();
if (thread_id == 0)
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
......@@ -215,7 +223,6 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
}
// Send numbers of tokens per rank/expert to NVL ranks
EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads);
if (thread_id < NUM_MAX_NVL_PEERS) {
#pragma unroll
for (int i = 0; i < kNumRDMARanks; ++i)
......@@ -225,10 +232,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] =
nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];
}
memory_fence();
__syncthreads();
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
__syncthreads();
// Reduce number of tokens per rank/expert
EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
......@@ -255,7 +259,6 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
}
// Finally barrier
__syncthreads();
if (thread_id == kWarpSize)
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
......@@ -355,12 +358,13 @@ void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mappe
auto nvl_clean_meta =
get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <=
num_rdma_bytes);
EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <=
num_nvl_bytes);
EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);
EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);
EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
// add assert origin kernel
EP_HOST_ASSERT(num_rdma_ranks <= kNumThreads);
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kNumThreads, "Assert NUM_MAX_NVL_PEERS <= kNumThreads");
// Launch kernel
SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream);
......@@ -1202,37 +1206,31 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
// Using two SMs, which clean the RDMA/NVL buffer respectively
if (sm_id == 0) {
// Barrier for RDMA
if (thread_id == 0)
if (thread_id == kWarpSize)
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
__syncthreads();
// Barrier for NVL
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
// Clean
// Clean RDMA buffer
auto rdma_buffer_ptr_int = reinterpret_cast<int *>(rdma_buffer_ptr);
for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
shmem_fence();
__syncthreads();
// Barrier again
if (thread_id == 0)
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
} else if (sm_id == 1) {
// Barrier for NVL
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
__syncthreads();
// Clean
// Clean NVL buffer
auto nvl_buffer_ptr_int = reinterpret_cast<int *>(buffer_ptrs[nvl_rank]);
for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
memory_fence();
__syncthreads();
// Barrier again
if (thread_id == kWarpSize)
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
// Barrier again
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
} else if (sm_id == 2) {
} else if (sm_id == 1) {
if (is_cached_dispatch)
return;
......@@ -1265,10 +1263,11 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and
rdma_rank_prefix_sum != nullptr);
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers");
constexpr int num_clean_sms = 2;
if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) {
for (int dst_rdma_rank = sm_id - 3; dst_rdma_rank < num_rdma_ranks;
dst_rdma_rank += num_channels * 2 - 3) {
for (int dst_rdma_rank = sm_id - num_clean_sms; dst_rdma_rank < num_rdma_ranks;
dst_rdma_rank += num_channels * 2 - num_clean_sms) {
// Iterate in reverse order
int token_start_idx =
warp_id == 0
......@@ -1319,7 +1318,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
num_nvl_bytes);
EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_channels * 2 > 3);
EP_HOST_ASSERT(num_channels * 2 > 2);
// Launch kernel
auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment