#include "hip/hip_runtime.h"
#include "buffer.cuh"
#include "configs.cuh"
#include "launch.cuh"
#include "utils.cuh"
#include "shmem_wrapper.cuh"

#ifndef DISABLE_ROCSHMEM

// Safety check: make sure the HIP version macro is defined.
#ifndef HIP_VERSION_PATCH
#error "HIP_VERSION_PATCH not defined! Check your HIP installation."
#endif

// TODO: fix unroll warnings
// #ifdef __clang__
// #pragma clang diagnostic push
// #pragma clang diagnostic ignored "-Wpass-failed"
// #pragma clang diagnostic ignored "-Wdeprecated-volatile"
// #endif  // __clang__

// NOTE(review): in this copy of the file several template parameter/argument
// lists appear to be missing. The `template` keywords below carry no parameter
// list. `static_cast`, `reinterpret_cast`, `std::pair`, `std::numeric_limits`,
// `SymBuffer`, `AsymBuffer` and `Buffer` are used without arguments. So are the
// calls to `translate_dst_rdma_rank` / `dushmem_barrier_with_same_gpu_idx` and
// the two `notify_dispatch` instantiations in the launch macro. The tokens are
// reproduced exactly as found; confirm them against the source of truth before
// building.

namespace deep_ep {

namespace internode {

// RDMA team handle, defined elsewhere. The host launcher passes it to the
// notify_dispatch kernel.
extern shmem_team_t cpu_rdma_team;

// Per-token routing metadata that travels with each dispatched token.
// Stores the source RDMA rank and a bitmask over the NVL ranks of the
// destination node: bit i is set iff is_token_in_nvl_ranks[i] is true.
struct SourceMeta {
    int src_rdma_rank, is_token_in_nvl_rank_bits;

    // sizeof(SourceMeta) = 8
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");

    __forceinline__ SourceMeta() = default;

    // Builds the bitmask from NUM_MAX_NVL_PEERS booleans, one bit per NVL rank.
    // TODO: faster encoding
    __device__ __forceinline__ SourceMeta(int rdma_rank, const bool *is_token_in_nvl_ranks) {
        src_rdma_rank = rdma_rank;
        is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0];
#pragma unroll
        for (int i = 1; i < NUM_MAX_NVL_PEERS; ++i)
            is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i;
    }

    // Returns whether bit `nvl_rank` of the packed mask is set.
    __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const {
        return (is_token_in_nvl_rank_bits >> nvl_rank) & 1;
    }
};

// Host-callable accessor for sizeof(SourceMeta).
int get_source_meta_bytes() { return sizeof(SourceMeta); }

// Bytes occupied by one token slot in the RDMA buffers. The slot holds:
//   - the hidden vector (hidden_int4 int4 elements),
//   - one SourceMeta,
//   - `num_scales` floats,
//   - `num_topk_idx` ints,
//   - `num_topk_weights` floats.
// The sum is passed through ALIGN with sizeof(int4) (presumably rounding up to
// a multiple of 16 bytes; confirm ALIGN's definition).
__host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_int4, int num_scales, int num_topk_idx,
                                                                     int num_topk_weights) {
    return static_cast(ALIGN(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) +
                                 num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float),
                             sizeof(int4)));
}

// Returns {offset, count}, both in int32_t units. They describe the region of
// the RDMA buffer that notify_dispatch zeroes before a dispatch.
//   - The offset skips the per-channel token data area:
//     num_bytes_per_rdma_token * num_rdma_recv_buffer_tokens bytes per RDMA rank
//     per channel, times 2 (presumably the send and receive halves of a
//     SymBuffer; confirm).
//   - The count covers (NUM_MAX_NVL_PEERS * 2 + 4) ints per RDMA rank per
//     channel, again times 2. That appears sized to hold the dispatch kernel's
//     per-channel `rdma_channel_meta` (NUM_MAX_NVL_PEERS * 2 + 2 ints),
//     `rdma_channel_head` (1) and `rdma_channel_tail` (1).
// NOTE(review): the byte product in the first element is evaluated in `int`,
// because get_num_bytes_per_rdma_token returns int and every other factor is
// int. Only the `/ sizeof(int)` widens it to size_t, so an oversized
// configuration could overflow before the host-side size assertion sees the
// value. get_nvl_clean_meta avoids this because its `sizeof` terms promote the
// product to size_t. Consider widening one factor to int64_t.
__host__ __device__ __forceinline__ std::pair get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx,
                                                                  int num_topk_weights, int num_rdma_ranks,
                                                                  int num_rdma_recv_buffer_tokens, int num_channels) {
    // Return `int32_t` offset and count to clean
    return {(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) *
             num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_channels) /
                sizeof(int),
            (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels};
}

// Returns {offset, count}, both in int32_t units. They describe the region of
// the NVL (intra-node) buffer that notify_dispatch zeroes before a dispatch.
//   - The offset skips the per-token NVL data area: num_nvl_recv_buffer_tokens
//     tokens per NVL rank per channel. Each token holds hidden + scales + top-k
//     indices + top-k weights + SourceMeta bytes, matching the dispatch
//     kernel's nvl_channel_x / x_scales / topk_idx / topk_weights / src_meta
//     buffers.
//   - The count covers num_nvl_ranks * (2 * num_rdma_ranks + 2) ints per
//     channel. That matches the per-channel `nvl_channel_prefix_start` and
//     `nvl_channel_prefix_end` (kNumRDMARanks ints each) plus
//     `nvl_channel_head` and `nvl_channel_tail` (1 int each).
__host__ __device__ __forceinline__ std::pair get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx,
                                                                 int num_topk_weights, int num_rdma_ranks,
                                                                 int num_nvl_ranks, int num_nvl_recv_buffer_tokens,
                                                                 int num_channels) {
    // Return the `int32_t` offset and the number of `int32_t` values to clean.
    EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");
    return {
        (num_nvl_recv_buffer_tokens *
         (hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
          num_topk_weights * sizeof(float) + sizeof(SourceMeta)) *
         num_nvl_ranks * num_channels) /
            sizeof(int),
        num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels,
    };
}

// Maps an RDMA rank to the PE index used for shmem operations.
// - In low-latency mode the global PE index is used:
//   dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank, i.e. the same NVL index on
//   the destination node.
// - Otherwise the RDMA rank is returned unchanged. Presumably that is the PE
//   index within the RDMA team; confirm.
template __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank, const int nvl_rank) {
    return kLowLatencyMode ? (dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank) : dst_rdma_rank;
}

// Inter-node barrier.
// - Low-latency mode: shmem_barrier over `rdma_team`.
// - Otherwise: shmem_device_barrier_all().
template __forceinline__ __device__ void dushmem_barrier_with_same_gpu_idx(const shmem_team_t &rdma_team) {
    // NOTE: shmem_device_barrier_all() might be an issue as
    // it doesn't follow OpenSHMEM specification on ROCm
    kLowLatencyMode ? shmem_barrier(rdma_team) : shmem_device_barrier_all();
}

// notify_dispatch (device): exchanges token-count metadata ahead of an
// internode dispatch and builds the prefix-sum tables that the dispatch reads.
//
// Template parameters (their declaration appears elided in this copy):
//   kLowLatencyMode - forwarded to translate_dst_rdma_rank and
//                     dushmem_barrier_with_same_gpu_idx.
//   kNumRDMARanks   - number of RDMA ranks; num_experts is divided evenly
//                     across them.
//
// Launch layout (see the host launcher below): 1 + num_rdma_ranks blocks of
// kNumThreads (256) threads.
//
// Block 0 exchanges per-rank and per-expert token counts in two stages:
// first with RDMA peers (PE chosen by translate_dst_rdma_rank), then with the
// NVL peers of this node. It also zeroes the RDMA and NVL "clean" regions.
// It then publishes:
//   - recv_rdma_rank_prefix_sum[i]: inclusive prefix over RDMA ranks of the
//     tokens that RDMA rank i (as seen through this GPU's peer PE) sends to
//     this node.
//   - recv_gbl_rank_prefix_sum[i]: inclusive prefix over global ranks of the
//     tokens this rank receives from rank i.
//   - *moe_recv_rdma_counter_mapped / *moe_recv_counter_mapped: the totals of
//     the two prefixes above.
//   - moe_recv_expert_counter_mapped[e]: tokens received for local expert e,
//     rounded up to a multiple of expert_alignment.
// Each mapped counter is written only after it reads back -1. Presumably -1 is
// a host-side "ready" sentinel; confirm with the caller.
//
// Block b >= 1 handles destination RDMA rank b - 1. For every channel it
// counts two things:
//   - the tokens routed to each NVL rank of that node (gbl_channel_prefix_matrix);
//   - the tokens routed to that node at all (rdma_channel_prefix_matrix).
// It then turns each row into an inclusive prefix sum over channels.
template __global__ void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                                         const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
                                         const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                                         int num_experts, const bool *is_token_in_rank, int num_tokens,
                                         int num_channels, int expert_alignment, const int rdma_clean_offset,
                                         const int rdma_num_int_clean, const int nvl_clean_offset,
                                         const int nvl_num_int_clean, int *rdma_channel_prefix_matrix,
                                         int *recv_rdma_rank_prefix_sum, int *gbl_channel_prefix_matrix,
                                         int *recv_gbl_rank_prefix_sum, void *rdma_buffer_ptr, void **buffer_ptrs,
                                         int **barrier_signal_ptrs, int rank, const shmem_team_t rdma_team) {
    auto sm_id = static_cast(blockIdx.x);
    auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    auto num_threads = static_cast(blockDim.x), num_warps = num_threads / kWarpSize;

    // Global rank = rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank.
    auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
    // Experts hosted per RDMA rank (node) and per NVL rank (GPU).
    auto num_rdma_experts = num_experts / kNumRDMARanks, num_nvl_experts = num_rdma_experts / NUM_MAX_NVL_PEERS;

    if (sm_id == 0) {
        // Communication with others.
        // Global barrier: thread kWarpSize (lane 0 of the second warp) performs
        // the inter-node barrier, while every thread joins barrier_block over
        // the NVL peers. This relies on the block having more than one warp.
        if (thread_id == kWarpSize)
            dushmem_barrier_with_same_gpu_idx(rdma_team);
        barrier_block(barrier_signal_ptrs, nvl_rank);

        // Send numbers of tokens per rank/expert to RDMA ranks.
        // Layout of the per-destination slot for RDMA rank r:
        //   [0, NUM_MAX_NVL_PEERS)                      tokens to each NVL rank of node r
        //   [NUM_MAX_NVL_PEERS, +num_rdma_experts)      tokens to each expert hosted on node r
        //   [NUM_MAX_NVL_PEERS + num_rdma_experts]      tokens to node r in total
        auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr);
        auto rdma_recv_num_tokens_mixed =
            SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks);

        // Clean up for later data dispatch. The count buffer above must lie
        // entirely before the region being zeroed.
        EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int));
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;

        // Copy to send buffer
        for (int i = thread_id; i < num_ranks; i += num_threads)
            rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] =
                num_tokens_per_rank[i];
        for (int i = thread_id; i < num_experts; i += num_threads)
            rdma_recv_num_tokens_mixed.send_buffer(i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] =
                num_tokens_per_expert[i];
        if (thread_id < kNumRDMARanks)
            rdma_recv_num_tokens_mixed.send_buffer(thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] =
                num_tokens_per_rdma_rank[thread_id];
        __syncthreads();

        // Issue send: one warp per destination RDMA rank.
        // - Remote ranks get a non-blocking warp put into their
        //   recv_buffer(rdma_rank).
        // - The local rank is copied directly.
        // TODO: more light fence or barrier or signaling
        // TODO: overlap EP barrier and NVL cleaning
        for (int i = warp_id; i < kNumRDMARanks; i += num_warps) {
            if (i != rdma_rank) {
                shmemx_int_put_nbi_warp(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
                                        rdma_recv_num_tokens_mixed.send_buffer(i),
                                        NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
                                        translate_dst_rdma_rank(i, nvl_rank));
            } else {
                UNROLLED_WARP_COPY(1, lane_id, NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
                                   rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
                                   rdma_recv_num_tokens_mixed.send_buffer(i), ld_volatile_global, st_na_global);
            }
        }
        __syncthreads();
        // Inter-node barrier before reading the received counts. Presumably it
        // also guarantees completion of the non-blocking puts above; confirm
        // for rocSHMEM.
        if (thread_id == 0)
            dushmem_barrier_with_same_gpu_idx(rdma_team);
        __syncthreads();

        // NVL buffers. Each of the first NUM_MAX_NVL_PEERS threads targets the
        // buffer of NVL peer `thread_id`; all threads read from this rank's
        // own buffer.
        auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[thread_id] : nullptr;
        auto nvl_recv_buffer = buffer_ptrs[nvl_rank];
        auto nvl_reduced_num_tokens_per_expert =
            Buffer(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer);
        auto nvl_send_num_tokens_per_rank = AsymBuffer(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
        auto nvl_send_num_tokens_per_expert = AsymBuffer(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
        auto nvl_recv_num_tokens_per_rank = AsymBuffer(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
        auto nvl_recv_num_tokens_per_expert = AsymBuffer(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);

        // Clean up for later data dispatch. The count buffers above must lie
        // entirely before the region being zeroed.
        auto nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]);
        EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes +
                             nvl_send_num_tokens_per_expert.total_bytes <=
                         nvl_clean_offset * sizeof(int));
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;

        // Reduce number of tokens per expert into the NVL send buffer:
        // sum, over RDMA ranks, the counts each one sent for this node's experts.
        // TODO: may use DUSHMEM reduction
        EP_DEVICE_ASSERT(num_rdma_experts <= num_threads);
        if (thread_id < num_rdma_experts) {
            int sum = 0;
#pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i)
                sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id];
            nvl_reduced_num_tokens_per_expert[thread_id] = sum;
        }
        __syncthreads();

        // Reduce RDMA received tokens into an inclusive prefix over RDMA ranks,
        // then publish the total through the mapped counter once it reads -1.
        if (thread_id == 0) {
            int sum = 0;
#pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i) {
                sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
                recv_rdma_rank_prefix_sum[i] = sum;
            }
            while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
                ;
            *moe_recv_rdma_counter_mapped = sum;
        }

        // Send numbers of tokens per rank/expert to NVL ranks. Thread t writes,
        // into NVL peer t's buffer, the per-RDMA-rank counts destined for peer t
        // and that peer's slice of the reduced per-expert counts.
        if (thread_id < NUM_MAX_NVL_PEERS) {
#pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i)
                nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] = rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id];
            for (int i = 0; i < num_nvl_experts; ++i)
                nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] =
                    nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];
        }
        barrier_block(barrier_signal_ptrs, nvl_rank);

        // Reduce number of tokens per rank/expert
        EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
        if (thread_id == 0) {
            // Inclusive prefix over global ranks of tokens received from each rank.
            int sum = 0;
            for (int i = 0; i < num_ranks; ++i) {
                int src_rdma_rank = i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS;
                sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
                recv_gbl_rank_prefix_sum[i] = sum;
            }
            while (ld_volatile_global(moe_recv_counter_mapped) != -1)
                ;
            *moe_recv_counter_mapped = sum;
        }
        if (thread_id < num_nvl_experts) {
            // Per-local-expert total over NVL peers, rounded up to expert_alignment.
            int sum = 0;
#pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
            sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
            while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1)
                ;
            moe_recv_expert_counter_mapped[thread_id] = sum;
        }

        // Finally barrier
        if (thread_id == kWarpSize)
            dushmem_barrier_with_same_gpu_idx(rdma_team);
        barrier_block(barrier_signal_ptrs, nvl_rank);
    } else {
        // Calculate meta data for destination RDMA rank sm_id - 1.
        // Each warp handles channels warp_id, warp_id + num_warps, ...
        int dst_rdma_rank = sm_id - 1;
        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

            // Iterate over tokens; lanes stride over the channel's token range.
            int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0};
            for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += kWarpSize) {
                // The NUM_MAX_NVL_PEERS booleans for the destination node are
                // read as a single 8-byte word. NOTE(review): this assumes the
                // address is 8-byte aligned (num_ranks a multiple of
                // NUM_MAX_NVL_PEERS and an aligned base); confirm.
                EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
                auto is_token_in_rank_uint64 =
                    *reinterpret_cast(is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS);
                auto is_token_in_rank_values = reinterpret_cast(&is_token_in_rank_uint64);
#pragma unroll
                for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j)
                    per_nvl_rank_count[j] += is_token_in_rank_values[j];
                // Count the token once if it goes to any NVL rank of this node.
                total_count += (is_token_in_rank_uint64 != 0);
            }

            // Warp reduce
            total_count = warp_reduce_sum(total_count);
#pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]);

            // Write into channel matrix (per-channel counts, not yet prefix-summed).
            if (lane_id == 0) {
#pragma unroll
                for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                    gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * num_channels + channel_id] =
                        per_nvl_rank_count[i];
                rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count;
            }
        }

        // Calculate prefix sum: turn each row into an inclusive prefix over channels.
        __syncthreads();
        if (thread_id == 0) {
            auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels;
            for (int i = 1; i < num_channels; ++i)
                prefix_row[i] += prefix_row[i - 1];
        }

        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
        if (thread_id < NUM_MAX_NVL_PEERS) {
            auto prefix_row = gbl_channel_prefix_matrix + (dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels;
            for (int i = 1; i < num_channels; ++i)
                prefix_row[i] += prefix_row[i - 1];
        }
    }
}

// Host launcher for the notify_dispatch kernel.
// 1. Computes the RDMA and NVL clean regions. `num_topk` is used for both the
//    top-k index and top-k weight counts.
// 2. Checks that those regions fit in the given buffer sizes.
// 3. Launches 1 + num_rdma_ranks blocks of kNumThreads threads on `stream`.
// The kernel instantiation is selected by `low_latency_mode` and by
// SWITCH_RDMA_RANKS (presumably on num_rdma_ranks). The NOTIFY_DISPATCH_LAUNCH_CASE
// macro is expanded once per supported RDMA-rank count.
void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                     const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
                     const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped, int num_experts,
                     const bool *is_token_in_rank, int num_tokens, int num_channels, int hidden_int4, int num_scales,
                     int num_topk, int expert_alignment, int *rdma_channel_prefix_matrix,
                     int *recv_rdma_rank_prefix_sum, int *gbl_channel_prefix_matrix, int *recv_gbl_rank_prefix_sum,
                     void *rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                     int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank, hipStream_t stream,
                     int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool low_latency_mode) {
#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks)                                                          \
    {                                                                                                        \
        auto notify_dispatch_func = low_latency_mode ? notify_dispatch                                      \
                                                     : notify_dispatch;                                     \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                                       \
            &cfg, notify_dispatch_func, num_tokens_per_rank, moe_recv_counter_mapped, num_ranks,             \
            num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, num_tokens_per_expert,                   \
            moe_recv_expert_counter_mapped, num_experts, is_token_in_rank, num_tokens,                       \
            num_channels, expert_alignment, rdma_clean_meta.first, rdma_clean_meta.second,                   \
            nvl_clean_meta.first, nvl_clean_meta.second, rdma_channel_prefix_matrix,                         \
            recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,                  \
            rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, cpu_rdma_team);                         \
    }                                                                                                        \
    break

    constexpr int kNumThreads = 256;
    const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Get clean meta
    auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
                                               num_max_rdma_chunked_recv_tokens, num_channels);
    auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
                                             NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
    // The clean regions must fit in the buffers. Buffer sizes must also stay
    // below the int limit, because the kernel receives offsets/counts as `int`.
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max());

    // The kernel selects one thread per RDMA rank (thread_id < kNumRDMARanks)
    // and one thread per NVL peer (thread_id < NUM_MAX_NVL_PEERS), so both
    // counts must fit in a single block.
    EP_HOST_ASSERT(num_rdma_ranks <= kNumThreads);
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kNumThreads, "Assert NUM_MAX_NVL_PEERS <= kNumThreads");

    // Launch kernel: block 0 for the count exchange, one block per destination RDMA rank.
    SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream);
    SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}

// At most 8 RDMA ranks to be sent: evaluates to min(num_rdma_ranks, 8).
constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { return num_rdma_ranks < 8 ?
num_rdma_ranks : 8; } template __global__ void __launch_bounds__(((1 + NUM_MAX_NVL_PEERS) * kWarpSize), 1) dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights, SourceMeta *recv_src_meta, const int4 *x, const float *x_scales, const int64_t *topk_idx, const float *topk_weights, int *send_rdma_head, int *send_nvl_head, int *recv_rdma_channel_prefix_matrix, int *recv_gbl_channel_prefix_matrix, const int *rdma_channel_prefix_matrix, const int *recv_rdma_rank_prefix_sum, const int *gbl_channel_prefix_matrix, const int *recv_gbl_rank_prefix_sum, const bool *is_token_in_rank, int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks) { enum class WarpRole { kRDMASender, // 从x写入到RDMA发送缓存 kRDMASenderCoordinator, // 从RDMA发送缓存写入到远端rdma_rank接收缓存 kRDMAAndNVLForwarder, // 从RDMA接收缓存转写到ipc nvl缓存 kForwarderCoordinator, // 向远端RDMA确认接收 kNVLReceivers // 从nvl缓存写入到recv_x }; #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) __shared__ shmem_ctx_t ctx; shmem_wg_ctx_create(&ctx); #endif const auto sm_id = static_cast(blockIdx.x); const auto num_threads = static_cast(blockDim.x), num_warps = num_threads / kWarpSize; const auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id(); const auto num_channels = static_cast(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, channel_id = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL; const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); EP_DEVICE_ASSERT(num_warps == 1 + NUM_MAX_NVL_PEERS); const auto role_meta = [=]() -> std::pair 
{ if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) { if(warp_id < kNumDispatchRDMASenderWarps) { return {WarpRole::kRDMASender, -1}; } else if(warp_id == kNumDispatchRDMASenderWarps) { return {WarpRole::kRDMASenderCoordinator, -1}; } } else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) { if(warp_id < NUM_MAX_NVL_PEERS) { return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; } else { return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS}; } } else { return {WarpRole::kNVLReceivers, (warp_id + channel_id + 1) % NUM_MAX_NVL_PEERS}; } }(); auto warp_role = role_meta.first; auto target_rank = role_meta.second; // Not applicable for RDMA senders // RDMA symmetric layout auto hidden_bytes = hidden_int4 * sizeof(int4); auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk); auto rdma_channel_data = SymBuffer(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_meta = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); // NVL buffer layouts // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers" void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr; int rs_wr_rank = 0, ws_rr_rank = 0; if (warp_role == WarpRole::kRDMAAndNVLForwarder) rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank; if (warp_role == WarpRole::kNVLReceivers) rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank; // 
Allocate buffers auto nvl_channel_x = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_src_meta = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_x_scales = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_topk_idx = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_topk_weights = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_prefix_start = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_prefix_end = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_head = AsymBuffer(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr); auto nvl_channel_tail = AsymBuffer(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); // RDMA sender warp synchronization __shared__ volatile int rdma_send_next_token_idx; __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks]; __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks]; // NVL and RDMA coordinate Forward warp synchronization __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks]; __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS]; 
// Place the main logic of your kernel here, using the parameters above. if(warp_role == WarpRole::kRDMASender) { /* 这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。 它首先获取当前通道的任务范围,然后清理共享内存,接着计算并发送本通道中的令牌数量。 然后,它遍历所有的令牌,读取每个令牌的RDMA秩的存在性,获取顺序锁,计算下一个尾部位置,存储RDMA头部,更新最后一个令牌尾部,释放顺序锁,并广播尾部位置。 最后,它复制相关的数据到对称发送缓冲区。 kRDMASender主要目的是将发送信息x, x_scale,source_meta, topk_idx, topk_weight等信息填充进入rdma发送缓存, 期间要同步warp直接对token的依序操作,以及和kForwarderCoordinator, kRDMASenderCoordinator内存同步。 同时在复制操作时, 使用ld.global.nc.L1::no_allocate.L2::256B, st.global.L1::no_allocate减少L1/L2缓存使用。 */ // 获取任务范围 int token_start_idx, token_end_idx; get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); // 清理共享内存 EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA秩数量"); if(warp_id == 0 && lane_id == 0) { rdma_send_next_token_idx = token_start_idx; } if(warp_id == 0 && lane_id < kNumRDMARanks) { rdma_send_channel_tail[lane_id] = 0; rdma_send_channel_next_tail[lane_id] = 0; } // 发送本通道中的令牌数量,通过 `-value - 1` 表示 EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= kWarpSize, "无效的NVL对等体数量"); // 对于每个目标RDMA秩,以warp为单位进行迭代。计算发送缓冲区的值,并存储在rdma_channel_meta.send_buffer中 // 用于填充rdma_channel_meta.send_buffer本节点发送到远端rank, rdma_rank的起始index和结束index for(int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) { auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank); if (lane_id < NUM_MAX_NVL_PEERS) { dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1; } else if (lane_id < NUM_MAX_NVL_PEERS * 2) { dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1; } else if (lane_id == NUM_MAX_NVL_PEERS * 2) { dst_ptr[lane_id] = -(channel_id == 0 ? 
0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1; } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) { dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1; } syncwarp(); if (dst_rdma_rank != rdma_rank) { #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) shmem_ctx_int_put_nbi_warp(ctx, #else shmemx_int_put_nbi_warp( #endif rdma_channel_meta.recv_buffer(rdma_rank), rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); } } #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) shmem_ctx_quiet(ctx); #else shmem_fence(); #endif // sync_rdma_sender_smem(); __syncthreads(); // 遍历令牌并复制到缓冲区 int64_t token_idx; int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1; auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id); for(token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) { // 读取RDMA秩的存在性 uint64_t is_token_in_rank_uint64 = 0; if(lane_id < kNumRDMARanks) { is_token_in_rank_uint64 = *reinterpret_cast(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS); } // 获得处理数据的自旋锁,获得锁后才会处理一些数据信息 while(lane_id == 0 && rdma_send_next_token_idx != token_idx) { // 等待 } syncwarp(); // 获取下一个尾部位置 int rdma_tail_idx = -1; if(is_token_in_rank_uint64 != 0) { rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++; // 与kForwarderCoordinator相互配合,调节发送数据的频率 while(rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) { cached_rdma_channel_head = static_cast(ld_volatile_global(rdma_channel_head.buffer(lane_id))); } } syncwarp(); // 存储RDMA头部以供合并 if(lane_id < kNumRDMARanks && !kCachedMode) { send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx; } // 更新最后一个令牌尾部 if(last_rdma_tail_idx >= 0) { st_release_cta(const_cast(rdma_send_channel_tail + lane_id), 
last_rdma_tail_idx + 1); } last_rdma_tail_idx = rdma_tail_idx; // 释放顺序锁 if(lane_id == 0) { rdma_send_next_token_idx += 1; } // 广播尾部位置 SourceMeta src_meta; int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks]; void* dst_send_buffers[kNumTopkRDMARanks]; /* 该for循环主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作 */ #pragma unroll for(int i = 0, slot_idx; i < kNumRDMARanks; ++i) { // 使用__shfl_sync函数在warp内同步并广播rdma_tail_idx的值 if((slot_idx = shfl_sync(rdma_tail_idx, i)) >= 0) { // warp 所有线程参与,rdma_tail_idx默认为-1, 只有对应rdma rank需要发送时, rdma_tail_idx才会>=0 // 计算slot_idx在接收缓冲区中的位置 slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens; // 存储当前RDMA秩到topk_ranks数组中 topk_ranks[num_topk_ranks] = i; // 广播is_token_in_rank_uint64的值到所有线程,并解释为布尔数组 auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i); auto recv_is_token_in_rank_values = reinterpret_cast(&recv_is_token_in_rank_uint64); // 如果当前lane_id等于num_topk_ranks,则更新src_meta if(lane_id == num_topk_ranks) { src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values); } // 计算目标发送缓冲区的地址,并存储在dst_send_buffers数组中 // 获取到发送地址, num_topk_ranks-1 是需要发送的ranks数 dst_send_buffers[num_topk_ranks++] = reinterpret_cast(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token; } } EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks); //////////////// 复制数据到发送缓冲区 //////////////// // 复制源元数据到对称发送缓冲区 if(lane_id < num_topk_ranks) { st_na_global(reinterpret_cast(dst_send_buffers[lane_id]), src_meta); } for(int i = 0; i < num_topk_ranks; ++i) { dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + 1; } // 复制 `x` 到对称发送缓冲区 auto st_broadcast = [=](const int key, const int4& value) { for(int j = 0; j < num_topk_ranks; ++j) { st_na_global(reinterpret_cast(dst_send_buffers[j]) + key, value); } }; UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast); for(int i = 0; i < num_topk_ranks; ++i) { dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + hidden_int4; } // 复制 `x_scales` 到对称发送缓冲区 
for(int i = lane_id; i < num_scales; i += kWarpSize) { auto value = ld_nc_global(x_scales + token_idx * num_scales + i); for(int j = 0; j < num_topk_ranks; ++j) { st_na_global(reinterpret_cast(dst_send_buffers[j]) + i, value); } } for(int i = 0; i < num_topk_ranks; ++i) { dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + num_scales; } // 复制 `topk_idx` 和 `topk_weights` 到对称发送缓冲区 for(int i = lane_id; i < num_topk * num_topk_ranks; i += kWarpSize) { auto rank_idx = i / num_topk, copy_idx = i % num_topk; auto idx_value = static_cast(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx)); auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx); st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + copy_idx, idx_value); st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value); } } // 结尾部分 // 获取顺序锁 while(lane_id == 0 && rdma_send_next_token_idx != token_idx) { // 等待 } syncwarp(); // 更新最后一个令牌尾部 if(last_rdma_tail_idx >= 0) { st_release_cta(const_cast(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1); } // 释放顺序锁 if(lane_id == 0) { rdma_send_next_token_idx += 1; } } else if(warp_role == WarpRole::kRDMASenderCoordinator) { /* 这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。 它首先计算每个RDMA秩需要发送的令牌数,然后在所有RDMA秩之间循环,检查是否有令牌需要发送。 如果有,它将计算本次需要发出的令牌数,并发出相应的RDMA发送请求。 最后,它更新相关的尾部位置,以便下次循环时可以正确地计算需要发送的令牌数。 kRDMASenderCoordinator使用了同sm内存一致性(ld.acquire.cta.s32), dushmem内存一致性(dushmem_fence)和原子操作(dushmemx_signal_op),减少硬同步,提升整体效率。 */ if(warp_id > kNumDispatchRDMASenderWarps) { return; } // 确保最大接收令牌数可以被最大发送令牌数整除,以避免缓冲区分割问题 EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); // 同步共享内存,确保所有线程在继续之前都达到了这一点 // sync_rdma_sender_smem(); __syncthreads(); // 计算当前通道需要发送的令牌数 int num_tokens_to_send = 0; if(lane_id < kNumRDMARanks) { num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id]; if(channel_id > 0) num_tokens_to_send -= 
rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1]; } // 记录上次发出的尾部位置 int last_issued_tail = 0; // 当有任何RDMA秩需要发送令牌时,继续循环 while(__any_sync(kFullWarpMask, num_tokens_to_send > 0)) { for(int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) { // 计算目标RDMA秩 int dst_rdma_rank = (i + channel_id) % kNumRDMARanks; // 获取同步后的需要发送的令牌数 synced_num_tokens_to_send = shfl_sync(num_tokens_to_send, dst_rdma_rank); if(synced_num_tokens_to_send == 0) continue; // 如果没有令牌需要发送,则跳过 // 读取进度 auto synced_last_issued_tail = shfl_sync(last_issued_tail, dst_rdma_rank); auto processed_tail = ld_acquire_cta(const_cast(rdma_send_channel_tail + dst_rdma_rank)); auto num_tokens_processed = processed_tail - synced_last_issued_tail; // 如果处理的令牌数不等于需要发送的令牌数,并且处理的令牌数小于最大发送令牌数,则跳过 if(num_tokens_processed != synced_num_tokens_to_send && num_tokens_processed < num_max_rdma_chunked_send_tokens) continue; // 计算本次需要发出的令牌数 auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens); EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 && num_tokens_to_issue <= synced_num_tokens_to_send); // 发出RDMA发送请求 if(dst_rdma_rank != rdma_rank) { auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens; EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) shmem_ctx_schar_put_nbi_warp(ctx, #else shmemx_int8_put_nbi_warp( #endif rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token, rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token, num_bytes_per_rdma_token * num_tokens_to_issue, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) shmem_ctx_quiet(ctx); #else shmem_fence(); #endif } else { // 对于本地RDMA秩,使用较轻的内存屏障 memory_fence(); } // 更新尾部位置 syncwarp(); if(lane_id == dst_rdma_rank) { last_issued_tail += num_tokens_to_issue; num_tokens_to_send -= 
num_tokens_to_issue; // 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信 #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) shmem_ctx_ulong_atomic_add(ctx, #else shmem_signal_op_add( #endif rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); } } } // while(__any(num_tokens_to_send > 0)) } else if(warp_role == WarpRole::kRDMAAndNVLForwarder) { /* 这段代码的主要功能是在一个CUDA内核中协调从RDMA消费者到NVL生产者的转发操作。 它首先计算目标NVL秩和目标秩,然后等待相关的计数器到达。 接着,它检查目标队列是否为空,或者等待一个缓冲区被释放。 然后,它找到下一个源RDMA秩,并遍历RDMA缓冲区中的每一个令牌,复制相关的数据到NVL缓冲区。 最后,它同步头部和尾部索引,并标记通道为退役状态。 */ // RDMA消费者和NVL生产者 const auto dst_nvl_rank = target_rank; // 目标NVL秩 const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank; // 目标秩 const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks); // 目标秩专家开始 const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks); // 目标秩专家结束 // 等待计数器到达 int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0; EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize); auto start_time = wall_clock64(); if(lane_id < kNumRDMARanks) { while(true) { // 对应于kRDMASender中的数据写入 auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank); // 是nvl节点的起始地址 auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); // nvl节点的结束地址 auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2); // 本rdma节点的起始地址 auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1); // 本节点的结束地址 if(meta_0 < 0 && meta_1 < 0 && meta_2 < 0 && meta_3 < 0) { // 通知NVL秩 int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1; EP_DEVICE_ASSERT(start_sum >= 0 && end_sum >= 0 && end_sum >= start_sum); st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1); st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1); // 保存从RDMA通道接收的令牌计数 
src_rdma_channel_prefix = -meta_2 - 1; auto src_rdma_channel_prefix_1 = -meta_3 - 1; num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; // 是远端 rdma_rank 会发送给当前节点的token数量 if(!kCachedMode) recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1; src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1]; // 对应的远端 rdma_rank 的起始index, 存在线程0之中 EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0); break; } // 超时检查 if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n", channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3); trap(); } } } syncwarp(); // 移动缓存的头部 send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank; // 等待共享内存被清理 // sync_forwarder_smem(); __syncthreads(); // 开始准备处理接受数据,直到所有的数据接受完成。 // 转发从RDMA缓冲区的令牌 // 注意:总是从本地秩开始 int src_rdma_rank = sm_id % kNumRDMARanks; int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0; int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0; while(__any_sync(kFullWarpMask, num_tokens_to_recv_from_rdma > 0)) { // 检查nvl目标队列是否为空,或者等待一个缓冲区被释放 start_time = wall_clock64(); // 用于给kNVLReceivers进行互动,控制数据的传输速度 while(lane_id == 0) { int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head; if(num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens) break; cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer()); // 超时检查 if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n", channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail); trap(); } } syncwarp(); // 找到下一个源RDMA秩(轮询) 
start_time = wall_clock64(); while(true) { src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks; if(shfl_sync(num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) { if(lane_id == src_rdma_rank && cached_rdma_channel_head == cached_rdma_channel_tail) cached_rdma_channel_tail = static_cast(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank))); if(shfl_sync(cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) { break; } } // 超时检查 if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n", channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma); trap(); } } auto src_rdma_head = shfl_sync(cached_rdma_channel_head, src_rdma_rank); auto src_rdma_tail = shfl_sync(cached_rdma_channel_tail, src_rdma_rank); // 遍历RDMA缓冲区中的每一个令牌 for(int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) { auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens; // 首先读取SourceMeta,对应到kRDMASenderCoordinator中 kRDMASender 的数据远程写入 void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token; auto src_meta = ld_nc_global(reinterpret_cast(reinterpret_cast(shifted))); if(lane_id == src_rdma_rank) { num_tokens_to_recv_from_rdma -= 1; } bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank); if(lane_id == src_rdma_rank) { auto cached_head = is_in_dst_nvl_rank ? 
rdma_nvl_token_idx : -1; rdma_nvl_token_idx += is_in_dst_nvl_rank; if(!kCachedMode) send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head; } if(!is_in_dst_nvl_rank) continue; // 获取一个空闲槽位 int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens; // 设置 src和dst 位置 auto src_gpu_buffer_x = reinterpret_cast(reinterpret_cast(shifted) + sizeof(SourceMeta)); auto src_gpu_buffer_scales = reinterpret_cast(reinterpret_cast(src_gpu_buffer_x) + hidden_bytes); auto src_gpu_buffer_topk_idx = reinterpret_cast(reinterpret_cast(src_gpu_buffer_scales) + num_scales * sizeof(float)); auto src_gpu_buffer_topk_weights = reinterpret_cast(reinterpret_cast(src_gpu_buffer_topk_idx) + num_topk * sizeof(int)); auto dst_gpu_buffer_x = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4; auto dst_gpu_buffer_scales = nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales; auto dst_gpu_buffer_topk_idx = nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk; auto dst_gpu_buffer_topk_weights = nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk; if(lane_id == 0) { st_na_global(reinterpret_cast(nvl_channel_src_meta.buffer() + dst_slot_idx), *reinterpret_cast(&src_meta)); } UNROLLED_WARP_COPY(5, lane_id, hidden_int4, dst_gpu_buffer_x, src_gpu_buffer_x, ld_direct_global, st_na_global); UNROLLED_WARP_COPY(1, lane_id, num_scales, dst_gpu_buffer_scales, src_gpu_buffer_scales, ld_direct_global, st_na_global); for(int t = lane_id; t < num_topk; t += kWarpSize) { int idx_val = ld_direct_global(reinterpret_cast(src_gpu_buffer_topk_idx) + t); float w_val = ld_direct_global(reinterpret_cast(src_gpu_buffer_topk_weights) + t); int new_idx = (idx_val >= dst_rank_expert_begin && idx_val < dst_rank_expert_end) ? (idx_val - dst_rank_expert_begin) : -1; float new_w = (new_idx != -1) ? 
w_val : 0.0f; dst_gpu_buffer_topk_idx[t] = new_idx; dst_gpu_buffer_topk_weights[t] = new_w; } // 在NVL缓冲区不足的情况下,提前停止 if((++num_tokens_sent) == num_max_nvl_chunked_send_tokens) src_rdma_tail = i + 1; } // 同步头部索引 if(lane_id == src_rdma_rank) forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail); // 移动尾部索引,与kNVLReceivers互相通信使用 syncwarp(); if(lane_id == 0) { st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail); } } // Retired syncwarp(); if(lane_id == 0) { forward_channel_retired[dst_nvl_rank] = true; } } else if(warp_role == WarpRole::kForwarderCoordinator) { /* 这段代码的主要功能是在一个CUDA内核中协调转发器的逻辑。 它首先检查当前warp是否是额外的转发器协调warp,如果是,则直接退出。 然后,它清理共享内存,并初始化转发通道的头部和退役状态。 接着,它进入一个无限循环,在循环中,它找到最小的头部,如果所有的通道都已退役,则退出循环。 否则,它更新远程头部,并进行纳秒级睡眠,以让其他warp工作。 */ // Extra warps for forwarder coordinator should exit directly if (warp_id > NUM_MAX_NVL_PEERS) return; // 转发warp协调器 EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量"); // 清理共享内存 EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "无效的NVL对等体数量"); #pragma unroll for(int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += kWarpSize) forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0; if(lane_id < NUM_MAX_NVL_PEERS) forward_channel_retired[lane_id] = false; // sync_forwarder_smem(); __syncthreads(); int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? 
lane_id : 0; while(true) { // 找到最小的头部 int min_head = std::numeric_limits::max(); #pragma unroll for(int i = 0; i < NUM_MAX_NVL_PEERS; ++i) if(!forward_channel_retired[i]) min_head = min(min_head, forward_channel_head[i][target_rdma]); if(__all_sync(kFullWarpMask, min_head == std::numeric_limits::max())) { break; } // 更新远程头部 if(min_head != std::numeric_limits::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){ #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) shmem_ctx_ulong_atomic_add(ctx, #else shmem_signal_op_add( #endif rdma_channel_head.buffer(rdma_rank), min_head - last_head, translate_dst_rdma_rank(lane_id, nvl_rank)); last_head = min_head; } // 纳秒级睡眠并让其他warp工作 // Nanosleep and let other warps work __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64); } } else if(warp_role == WarpRole::kNVLReceivers) { if(warp_id >= NUM_MAX_NVL_PEERS) { return; } // Place the main logic of your kernel here, using the parameters above. // NVL消费者 // 从屏障结果中检索秩偏移(每个通道的寄存器存储一个RDMA秩) int src_nvl_rank = target_rank, total_offset = 0; EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量"); if(lane_id < kNumRDMARanks && lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0) total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1]; // 接收通道偏移 int start_offset = 0, end_offset = 0, num_tokens_to_recv; auto start_time = wall_clock64(); while(lane_id < kNumRDMARanks) { start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id); end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id); if(start_offset < 0 && end_offset < 0) { start_offset = -start_offset - 1, end_offset = -end_offset - 1; total_offset += start_offset; break; } // 超时检查 if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n", channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, 
start_offset, end_offset); trap(); } } num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset); // 保存以供合并使用 if(lane_id < kNumRDMARanks && !kCachedMode) recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset; syncwarp(); int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; while(num_tokens_to_recv > 0) { // 通过通道0检查通道状态 start_time = wall_clock64(); while(lane_id == 0) { // 准备复制 if(cached_channel_head_idx != cached_channel_tail_idx) break; cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer()); // 超时检查 if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n", channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx); trap(); } } // 同步队列尾部 cached_channel_tail_idx = shfl_sync(cached_channel_tail_idx, 0); // 复制数据 int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx; for(int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) { int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens; auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer); int64_t recv_token_idx = shfl_sync(total_offset, meta.src_rdma_rank); (lane_id == meta.src_rdma_rank) ? 
(total_offset += 1) : 0;
                // Copy the hidden data
                UNROLLED_WARP_COPY(5, lane_id, hidden_int4, recv_x + recv_token_idx * hidden_int4, nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4, ld_nc_global, st_na_global);

                // Copy the source metadata (single lane; skipped in cached mode)
                if(lane_id == 0 && !kCachedMode)
                    st_na_global(recv_src_meta + recv_token_idx, meta);

                // Copy the scales
                UNROLLED_WARP_COPY(1, lane_id, num_scales, recv_x_scales + recv_token_idx * num_scales, nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales, ld_nc_global, st_na_global);

                // Copy `topk_idx` and `topk_weights`, one top-k entry per lane.
                // NOTE(review): unlike the forwarder's strided `t += kWarpSize` copy loop, only lanes
                // < num_topk participate here, so this assumes num_topk <= kWarpSize — confirm.
                if(lane_id < num_topk) {
                    auto recv_idx = recv_token_idx * num_topk + lane_id;
                    auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
                    st_na_global(recv_topk_idx + recv_idx, static_cast(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
                    st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
                }
            }

            // Advance the queue head; the forwarder polls `nvl_channel_head` to find free slots
            syncwarp();
            if(lane_id == 0) {
                st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
            }
        } // while(num_tokens_to_recv > 0)
    }
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
    // NOTE(review): warps that returned early in the role branches above never reach this call;
    // confirm shmem_wg_ctx_destroy does not require every thread of the work-group.
    shmem_wg_ctx_destroy(&ctx);
#endif
}

// Host-side launcher for the internode `dispatch` kernel.
//
// Picks one of four kernel instantiations from (low_latency_mode, is_cached_dispatch), turns the
// RDMA-rank count into a compile-time value through SWITCH_RDMA_RANKS / DISPATCH_LAUNCH_CASE, and
// launches it non-cooperatively on `stream` with num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
// blocks of (1 + NUM_MAX_NVL_PEERS) * kWarpSize threads.
//
// Host-side preconditions (EP_HOST_ASSERT):
//   - `num_scales * scale_hidden_stride` stays below the numeric limit checked below;
//   - `topk_idx` / `topk_weights` are both null or both non-null, and likewise
//     `recv_topk_idx` / `recv_topk_weights`.
void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights, void *recv_src_meta,
              const void *x, const float *x_scales, const int64_t *topk_idx, const float *topk_weights,
              int *send_rdma_head, int *send_nvl_head,
              int *recv_rdma_channel_prefix_matrix, int *recv_gbl_channel_prefix_matrix,
              const int *rdma_channel_prefix_matrix, const int *recv_rdma_rank_prefix_sum,
              const int *gbl_channel_prefix_matrix, const int *recv_gbl_rank_prefix_sum,
              const bool *is_token_in_rank,
              int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
              int scale_token_stride, int scale_hidden_stride,
              void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
              void **buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
              int rank, int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels, bool low_latency_mode) {
    // Number of RDMA-sender warps; not referenced by the code visible in this function body —
    // presumably forwarded to the kernel as a compile-time parameter (confirm against the kernel template).
    constexpr int kNumDispatchRDMASenderWarps = 7;

    // Make sure never OOB: bound `num_scales * scale_hidden_stride`, presumably so that scale
    // offsets computed in `int` inside the kernel cannot overflow
    EP_HOST_ASSERT(static_cast(num_scales) * scale_hidden_stride < std::numeric_limits::max());

    // One case of the RDMA-rank switch: choose the kernel variant for
    // (low_latency_mode, is_cached_dispatch) and launch it. The trailing `break` presumably
    // terminates the `case` emitted by SWITCH_RDMA_RANKS.
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \
    { \
        auto dispatch_func = \
            low_latency_mode \
                ? (is_cached_dispatch \
                       ? dispatch \
                       : dispatch) \
                : (is_cached_dispatch \
                       ? dispatch \
                       : dispatch); \
        LAUNCH_KERNEL_NON_COOPERATIVE( \
            &cfg, dispatch_func, reinterpret_cast(recv_x), recv_x_scales, recv_topk_idx, \
            recv_topk_weights, reinterpret_cast(recv_src_meta), \
            reinterpret_cast(x), x_scales, topk_idx, topk_weights, send_rdma_head, \
            send_nvl_head, recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \
            rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
            recv_gbl_rank_prefix_sum, is_token_in_rank, num_tokens, hidden_int4, num_scales, \
            num_topk, num_experts, scale_token_stride, scale_hidden_stride, rdma_buffer_ptr, \
            num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs, \
            num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks); \
    } \
    break

    // Top-k indices and weights must be provided (or omitted) together, on both send and receive sides
    EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
    EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));

    // Grid: NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL blocks per channel; block: (1 + NUM_MAX_NVL_PEERS) warps
    SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, (1 + NUM_MAX_NVL_PEERS) * kWarpSize, stream);
    SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}

// Post-dispatch preparation kernel. Block roles, selected by blockIdx.x:
//   - block 0:     RDMA-team barrier + NVL barrier, zero `rdma_num_int_clean` ints at int offset
//                  `rdma_clean_offset` of the RDMA buffer and `nvl_num_int_clean` ints at int offset
//                  `nvl_clean_offset` of this rank's NVL buffer, then both barriers again.
//   - block 1:     unless `is_cached_dispatch`, rewrite negative entries of `combined_rdma_head`.
//   - blocks >= 2: unless `is_cached_dispatch`, rewrite negative entries of `combined_nvl_head`.
// The rewrite walks each channel's tokens in reverse and replaces every negative head with
// `-last_head - 1`, where `last_head` is the nearest non-negative head that follows it in the
// channel (or `1 << 25` if none follows).
// NOTE(review): presumably this lets the combine step advance queue heads past skipped tokens — confirm.
template __global__ void __launch_bounds__(1024, 1)
cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean,
              const int nvl_clean_offset, const int nvl_num_int_clean,
              int *combined_rdma_head, int num_combined_tokens, int num_channels,
              const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum, int *combined_nvl_head,
              void *rdma_buffer_ptr, void **buffer_ptrs, int **barrier_signal_ptrs,
              int rank, int num_ranks, bool is_cached_dispatch,
              const shmem_team_t rdma_team) {
    auto sm_id = static_cast(blockIdx.x);
    auto thread_id = static_cast(threadIdx.x);
    auto num_threads = static_cast(blockDim.x);
    auto num_warps = num_threads / kWarpSize;
    auto warp_id = thread_id / kWarpSize;
    auto lane_id = get_lane_id();
    auto nvl_rank = rank % NUM_MAX_NVL_PEERS;
    auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Block 0 synchronizes and cleans both buffers; block 1 rewrites the RDMA combine heads;
    // the remaining blocks rewrite the NVL combine heads
    if (sm_id == 0) {
        // Barrier for RDMA (issued by a single thread: lane 0 of warp 1)
        if (thread_id == kWarpSize)
            dushmem_barrier_with_same_gpu_idx(rdma_team);

        // Barrier for NVL
        barrier_block(barrier_signal_ptrs, nvl_rank);

        // Clean RDMA buffer
        auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr);
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;

        // Clean NVL buffer (this rank's own buffer)
        auto nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]);
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
        __syncthreads();

        // Barrier again (RDMA), after the whole block has finished cleaning
        if (thread_id == kWarpSize)
            dushmem_barrier_with_same_gpu_idx(rdma_team);

        // Barrier again (NVL)
        barrier_block(barrier_signal_ptrs, nvl_rank);
    } else if (sm_id == 1) {
        if (is_cached_dispatch)
            return;

        // One lane per RDMA rank
        EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize);

        // Each warp handles channels warp_id, warp_id + num_warps, ...; lane `r` owns the entries
        // `combined_rdma_head[token_idx * num_rdma_ranks + r]` of the channel's tokens.
        // Iterate in reverse order
        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
            if (lane_id < num_rdma_ranks) {
                int token_start_idx, token_end_idx;
                get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

                // NOTES: `1 << 25` is a heuristic large number
                int last_head = 1 << 25;
                for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
                    auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
                    if (current_head < 0) {
                        combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
                    } else {
                        last_head = current_head;
                    }
                }
            }
        }
    } else {
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);
        // One lane per NVL peer
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers");

        // Blocks 0 and 1 are reserved for the roles above. The remaining blocks split the destination
        // RDMA ranks between them with stride `num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL - 2`,
        // i.e. the host launch's block count minus the two reserved blocks.
        constexpr int num_clean_sms = 2;
        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
            if (lane_id < NUM_MAX_NVL_PEERS ) {
                for (int dst_rdma_rank = sm_id - num_clean_sms; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL - num_clean_sms) {
                    // Token range of (dst_rdma_rank, channel_id): the channel prefix within this RDMA rank,
                    // shifted by `rdma_rank_prefix_sum[dst_rdma_rank - 1]` (presumably the token count of
                    // all preceding RDMA ranks)
                    int token_start_idx = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
                    int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
                    int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
                    token_start_idx += shift, token_end_idx += shift;

                    // Iterate in reverse order; lane `p` owns `combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + p]`
                    // NOTES: `1 << 25` is a heuristic large number
                    int last_head = 1 << 25;
                    for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
                        auto current_head = __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
                        if (current_head < 0) {
                            combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
                        } else {
                            last_head = current_head;
                        }
                    }
                }
            }
        }
    }
}

// Host-side launcher for the `cached_notify` kernel.
// Computes the int32 ranges of the RDMA and NVL buffers to zero, checks that they fit inside
// `num_rdma_bytes` / `num_nvl_bytes`, and launches num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
// blocks on `stream`, using the kernel variant selected by `low_latency_mode`.
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                   int num_ranks, int num_channels, int num_combined_tokens,
                   int *combined_rdma_head, const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
                   int *combined_nvl_head, void *rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
                   void **buffer_ptrs, int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs,
                   int rank, hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
                   bool is_cached_dispatch, bool low_latency_mode) {
    // One warp per channel, clamped to [128, 1024] threads; kernel warps loop over any extra channels
    const int num_threads = ::min(1024, ::max(128, kWarpSize * num_channels));
    const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Get clean meta: {int32 offset, int32 count} of the region to zero in each buffer
    auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels);
    auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
    // The regions to clean must lie inside the allocated buffers
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);
    // The kernel receives offsets/counts as `int`, so the buffer sizes must stay below this limit
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max());
    // Blocks 0 and 1 have dedicated roles; at least one more block is required, otherwise the
    // kernel's NVL-head loop stride (block count - 2) would be zero
    EP_HOST_ASSERT(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL > 2);

    // Launch kernel
    auto cached_notify_func = low_latency_mode ? cached_notify : cached_notify;
    SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, num_threads, stream);
    LAUNCH_KERNEL_NON_COOPERATIVE(
        &cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second, nvl_clean_meta.first, nvl_clean_meta.second,
        combined_rdma_head, num_combined_tokens, num_channels, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head,
        rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, num_ranks, is_cached_dispatch, cpu_rdma_team);
}

// Warp-cooperative reduction of one token's copies received from up to kNumRanks source ranks.
//
// Every lane calls this (it uses shfl_sync). Lane `i` (< kNumRanks) passes whether the token is
// present on rank `i` (`is_token_in_rank`) and that rank's queue head (`head_idx`).
//   - hidden: for each int4 chunk `i` handled by this lane (lane_id, lane_id + kWarpSize, ...),
//     the `dtype_t` elements fetched via `get_addr_fn(rank, slot, i)` from every present rank are
//     summed in float and stored back as `dtype_t` into `combined_row[i]`;
//   - top-k weights: lanes < num_topk sum `recv_tw_fn(rank, slot, lane_id)` over the present ranks
//     into `combined_topk_weights[lane_id]`.
// `slot` is `head % num_max_recv_tokens`. Returns the lowest-numbered present rank.
template __device__ int combine_token(bool is_token_in_rank, int head_idx, int lane_id, int hidden_int4, int num_topk,
                                      int4* combined_row, float* combined_topk_weights,
                                      int num_max_recv_tokens, const GetAddrFn& get_addr_fn, const ReceiveTWFn& recv_tw_fn) {
    constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);

    // Broadcast current heads
    // Lane `i` holds the head of rank `i` and `is_token_in_rank`
    EP_STATIC_ASSERT(kMaxNumRanks <= kWarpSize, "Too many ranks");
    int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];
    // Collect the ranks holding this token, in ascending rank order, with their buffer slots
    #pragma unroll
    for (int i = 0; i < kNumRanks; ++ i) if (shfl_sync(is_token_in_rank, i)) {
        slot_indices[num_topk_ranks] = shfl_sync(head_idx, i) % num_max_recv_tokens;
        topk_ranks[num_topk_ranks ++] = i;
    }
    // NOTE(review): evaluated only after the writes above; it holds whenever kNumRanks <= kMaxNumRanks
    EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks);

    // Reduce data
    #pragma unroll
    for (int i = lane_id; i < hidden_int4; i += kWarpSize) {
        // Read buffers; one float accumulator per `dtype_t` element of an int4
        // (kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t), e.g. 8 for a 2-byte dtype_t)
        float values[kDtypePerInt4] = {0};
        #pragma unroll
        for (int j = 0; j < num_topk_ranks; ++j) {
            int4 recv_value = ld_nc_global(get_addr_fn(topk_ranks[j], slot_indices[j], i));
            auto recv_dtypes = reinterpret_cast(&recv_value);
            #pragma unroll
            for (int k = 0; k < kDtypePerInt4; ++k)
                values[k] += static_cast(recv_dtypes[k]);
        }

        // Cast back to `dtype_t` and write
        int4 out_int4;
        auto out_dtypes = reinterpret_cast(&out_int4);
        #pragma unroll
        for (int j = 0; j < kDtypePerInt4; ++ j)
            out_dtypes[j] = static_cast(values[j]);
        st_na_global(combined_row + i, out_int4);
    }

    // Reduce `topk_weights`: one top-k entry per lane
    // NOTE(review): entries with index >= kWarpSize would not be reduced; assumes num_topk <= kWarpSize.
    if (lane_id < num_topk) {
        float value = 0;
        #pragma unroll
        for (int i = 0; i < num_topk_ranks; ++ i)
            value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id);
        st_na_global(combined_topk_weights + lane_id, value);
    }

    // Return the minimum top-k rank
    // NOTE(review): if no lane reported the token, topk_ranks[0] is read uninitialized; callers
    // presumably guarantee at least one source rank — confirm.
    return topk_ranks[0];
}

template 0) ?
kNumCombineForwarderWarps / kNumRDMARanks : 1, int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder, int kNumRDMAReceivers = kNumForwarders> __global__ void __launch_bounds__((1 + NUM_MAX_NVL_PEERS) * kWarpSize, 1) combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_token_in_rank, const int4 *x, const float *topk_weights, const int4 *bias_0, const int4 *bias_1, const int *combined_rdma_head, const int *combined_nvl_head, const SourceMeta *src_meta, const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum, const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens, int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks) { enum class WarpRole { kNVLSender, kNVLAndRDMAForwarder, kRDMAReceiver, kRDMACoordinator, kNVLCoordinator }; #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) __shared__ shmem_ctx_t ctx; shmem_wg_ctx_create(&ctx); #endif EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps"); const auto sm_id = static_cast(blockIdx.x); const auto num_threads = static_cast(blockDim.x), num_warps = num_threads / kWarpSize; const auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id(); const auto num_channels = static_cast(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, channel_id = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL; const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t)); // NOTES: we decouple a channel into 2 SMs const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; const auto role_meta = [=]() -> std::pair { if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) { return {WarpRole::kNVLSender, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; } else if 
(sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) { if(warp_id < kNumForwarders) { return {WarpRole::kNVLAndRDMAForwarder, (warp_id + channel_id) % kNumForwarders}; } else { return {WarpRole::kRDMACoordinator, 0}; } } else { if(warp_id < kNumForwarders) { return {WarpRole::kRDMAReceiver, warp_id}; } else { return {WarpRole::kNVLCoordinator, 0}; } } }(); auto warp_role = role_meta.first; auto target_rank = role_meta.second; // Not applicable for RDMA senders EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + 1); auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks; // This approach is designed to sync multiple warps in a loop constexpr int num_sync_large_iteration = 64; constexpr int rdma_warp_counters = kNumRDMARanks * num_sync_large_iteration; __shared__ volatile int sync_large_warp_counters[2 * rdma_warp_counters]; for (int i = thread_id; i < 2 * rdma_warp_counters; i += num_threads) { sync_large_warp_counters[i] = 0; } __syncthreads(); if (warp_role == WarpRole::kNVLSender) { if(warp_id >= NUM_MAX_NVL_PEERS) { return; } const auto dst_nvl_rank = target_rank; // NVL layouts // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank]; auto nvl_channel_x = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); auto nvl_channel_src_meta = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); auto nvl_channel_topk_weights = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); auto nvl_channel_head = AsymBuffer(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, 
dst_nvl_rank).advance_also(dst_buffer_ptr); auto nvl_channel_tail = AsymBuffer(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); // Get tasks for each RDMA lane int token_start_idx = 0, token_end_idx = 0; if(lane_id < kNumRDMARanks) { int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id; token_start_idx = gbl_channel_prefix_matrix[prefix_idx]; token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1]; } syncwarp(); // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers"); // Iterate over all tokens and send by chunks while(true) { // Exit if possible if(__all_sync(kFullWarpMask, token_start_idx >= token_end_idx)) break; // Decide next RDMA buffer to send bool is_lane_ready = false; auto start_time = wall_clock64(); while(true) { int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx; is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens; if(__any_sync(kFullWarpMask, is_lane_ready)) break; // Retry if(lane_id < kNumRDMARanks and token_start_idx < token_end_idx) cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id); // Timeout check if(wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, " "RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n", channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, ld_volatile_global(nvl_channel_head.buffer() + lane_id), cached_channel_tail_idx, token_start_idx, token_end_idx); trap(); } } // Sync token start index and count 
for(int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++current_rdma_idx) { if(shfl_sync((token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx)) continue; // Sync token start index auto token_idx = static_cast(shfl_sync(token_start_idx, current_rdma_idx)); int num_tokens_in_chunk = shfl_sync(min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx); // Send by chunk for(int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) { // Get an empty slot int dst_slot_idx = 0; if(lane_id == current_rdma_idx) { dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma; dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx; } dst_slot_idx = shfl_sync(dst_slot_idx, current_rdma_idx); // Copy data auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4; auto shifted_x = x + token_idx * hidden_int4; UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global); // Copy source meta if(lane_id == 0) st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx)); // Copy `topk_weights` if(lane_id < num_topk) st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, ld_nc_global(topk_weights + token_idx * num_topk + lane_id)); } lane_id == current_rdma_idx ? 
(token_start_idx = static_cast(token_idx)) : 0; } // Move queue tail syncwarp(); if(lane_id < kNumRDMARanks and is_lane_ready) { st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx); } } } else { if(warp_id > kNumForwarders) { return; } // Combiners and coordinators // RDMA symmetric layout auto hidden_bytes = hidden_int4 * sizeof(int4); auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk); auto rdma_channel_data = SymBuffer(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); // NVL layouts void* local_nvl_buffer = buffer_ptrs[nvl_rank]; void* nvl_buffers[NUM_MAX_NVL_PEERS]; #pragma unroll for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) nvl_buffers[i] = buffer_ptrs[i]; auto nvl_channel_x = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); auto nvl_channel_src_meta = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); auto nvl_channel_topk_weights = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); auto nvl_channel_head = AsymBuffer(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer); auto nvl_channel_tail = AsymBuffer(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); // Combiner warp synchronization __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS]; __shared__ volatile bool forwarder_retired[kNumForwarders]; __shared__ volatile int 
rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks]; __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers]; if (warp_role == WarpRole::kNVLAndRDMAForwarder) { // Receive from NVL ranks and forward to RDMA ranks // NOTES: this part is using "large warps" for each RDMA ranks const auto dst_rdma_rank = target_rank / kNumWarpsPerForwarder; const auto sub_warp_id = target_rank % kNumWarpsPerForwarder; auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank); // auto sync_large_warp = [=]() { // if(kNumWarpsPerForwarder == 1) { // syncwarp(); // } else { // // asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * kWarpSize)); // // __syncthreads(); // syncwarp(); // } // }; auto sync_large_warp = [=](const int iter, const int mode) { if (kNumWarpsPerForwarder == 1) { syncwarp(); } else { // LDS index to store for sync int lds_dst_rdma_rank = dst_rdma_rank + (iter % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters; //reset index in the LDS to avoid race condition due to warp scheduling int reset_idx = dst_rdma_rank + ((iter + num_sync_large_iteration/2) % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters; auto start_time = wall_clock64(); if (lane_id == 0){ volatile int ret = atomicAdd((int*)&sync_large_warp_counters[lds_dst_rdma_rank], 1); } syncwarp(); //The while(...) loop polls the counter until all warps have arrived if (lane_id == 0){ while (sync_large_warp_counters[lds_dst_rdma_rank] < (kNumWarpsPerForwarder)){ if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf("DeepEP combine sync timeout. current num_sync_large_iteration %d. 
double it.\n", num_sync_large_iteration ); trap(); } } } syncwarp(); if (lane_id == 0 && sync_large_warp_counters[reset_idx] == kNumWarpsPerForwarder){ sync_large_warp_counters[reset_idx] = 0; } syncwarp(); } }; EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= kNumCombineForwarderWarps, "Barriers are not enough"); // Advance to the corresponding NVL buffer, 基于原本指针进行的地址偏移 nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4); nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma); nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk); nvl_channel_head.advance(dst_rdma_rank); nvl_channel_tail.advance(dst_rdma_rank); // Clean shared memory and sync EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers"); lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[target_rank][lane_id] = 0) : 0; lane_id == 0 ? (forwarder_retired[target_rank] = false) : false; // sync_forwarder_smem(); __syncthreads(); // Get count and cached head int cached_nvl_channel_tail_idx = 0; int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id]; int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]; num_tokens_to_combine -= num_tokens_prefix; num_tokens_prefix += dst_rdma_rank == 0 ? 
0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS; // Iterate over all tokens and combine by chunks for(int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) { // Check destination queue emptiness, or wait a buffer to be released auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine); auto num_chunked_tokens = token_end_idx - token_start_idx; auto start_time = wall_clock64(); while(sub_warp_id == 0 and lane_id == 0) { // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens` // Here, `token_start_idx` is the actual tail int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)); if(num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens) break; // Timeout check if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n", channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens); trap(); } } // sync_large_warp(); sync_large_warp(token_start_idx, 0); // Combine and write to the RDMA buffer for(int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) { // Read expected head EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers"); int expected_head = -1; if(lane_id < NUM_MAX_NVL_PEERS) expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); // Wait lanes to be ready start_time = wall_clock64(); while(cached_nvl_channel_tail_idx <= expected_head) { cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id)); // Timeout check if (wall_clock64() - start_time > 
NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) { printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n", channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head); trap(); } } // Combine current token auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token; auto get_addr_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4) + hidden_int4_idx; }; auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); }; combine_token(expected_head >= 0, expected_head, lane_id, hidden_int4, num_topk, reinterpret_cast(shifted), reinterpret_cast(reinterpret_cast(shifted) + hidden_bytes + sizeof(SourceMeta)), num_max_nvl_chunked_recv_tokens_per_rdma, get_addr_fn, recv_tw_fn); // Update head if(lane_id < NUM_MAX_NVL_PEERS) { expected_head < 0 ? 
(forwarder_nvl_head[target_rank][lane_id] = -expected_head - 1) : (forwarder_nvl_head[target_rank][lane_id] = expected_head + 1); } } // sync_large_warp(); sync_large_warp(token_start_idx, 1); // Issue RDMA send if(sub_warp_id == kNumWarpsPerForwarder - 1) { if(dst_rdma_rank != rdma_rank) { auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) shmem_ctx_schar_put_nbi_warp(ctx, #else shmemx_int8_put_nbi_warp( #endif rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token, rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token, num_chunked_tokens * num_bytes_per_rdma_token, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) shmem_ctx_quiet(ctx); #else shmem_fence(); #endif } else { memory_fence(); } // Write new RDMA tail syncwarp(); if(lane_id == 0) { #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) shmem_ctx_ulong_atomic_add(ctx, #else shmem_signal_op_add( #endif rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); } } } // Retired syncwarp(); if(lane_id == 0) { forwarder_retired[target_rank] = true; } } else if (warp_role == WarpRole::kRDMACoordinator) { // Coordinator // Sync shared memory status // sync_forwarder_smem(); __syncthreads(); constexpr int num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks; int last_nvl_head[kNumRDMARanks] = {0}; int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? 
lane_id : 0; while(true) { // Retired if(__all_sync(kFullWarpMask, lane_id >= kNumForwarders or forwarder_retired[lane_id])) break; { // Find minimum head for NVL ranks #pragma unroll for(int i = 0; i < kNumRDMARanks; ++i) { int min_head = std::numeric_limits::max(); #pragma unroll for(int j = 0; j < num_warps_per_rdma_rank; ++j) if(not forwarder_retired[i * num_warps_per_rdma_rank + j]) min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]); if(min_head != std::numeric_limits::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) { st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head); } } } // Nanosleep and let other warps work __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64); } } else if(warp_role == WarpRole::kRDMAReceiver) { // Receive from RDMA ranks and write to the output tensor // Clean shared memory and sync EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize); lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[target_rank][lane_id] = 0) : 0; lane_id == 0 ? 
(rdma_receiver_retired[target_rank] = false) : 0; // sync_rdma_receiver_smem(); __syncthreads(); // The same tokens as the dispatch process int token_start_idx, token_end_idx; get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx); // ==================== Token 级展开 x4 ==================== constexpr int kTokenUnroll = 4; int cached_channel_tail_idx = 0; for (int64_t base = token_start_idx + target_rank; base < token_end_idx; base += (int64_t)kNumRDMAReceivers * kTokenUnroll) { // ---- Phase 1: 批量预取所有 token 的 expected_head ---- int cached_expected_head[kTokenUnroll]; int max_expected_head = -1; #pragma unroll for (int u = 0; u < kTokenUnroll; ++u) { int64_t tidx = base + (int64_t)u * kNumRDMAReceivers; cached_expected_head[u] = -1; if (tidx < token_end_idx && lane_id < kNumRDMARanks) { int expected_head = ld_nc_global(combined_rdma_head + tidx * kNumRDMARanks + lane_id); cached_expected_head[u] = expected_head; if (expected_head > max_expected_head) max_expected_head = expected_head; } } // ---- Phase 2: 一次等待,覆盖所有 token ---- if (max_expected_head >= 0) { auto start_time = wall_clock64(); while (cached_channel_tail_idx <= max_expected_head) { cached_channel_tail_idx = static_cast( ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id))); if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf("DeepEP combine RDMA receiver timeout (unroll x%d), " "ch: %d, rdma: %d, nvl: %d, lane: %d, " "tail: %d, wait: %d\n", kTokenUnroll, channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, max_expected_head); trap(); } } } syncwarp(); // ---- Phase 3: 批量处理所有就绪 token ---- #pragma unroll for (int u = 0; u < kTokenUnroll; ++u) { int64_t tidx = base + (int64_t)u * kNumRDMAReceivers; if (tidx < token_end_idx) { int expected_head = cached_expected_head[u]; // Combine current token auto get_addr_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4* { return 
reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx; }; auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);}; combine_token( expected_head >= 0, expected_head, lane_id, hidden_int4, num_topk, combined_x + tidx * hidden_int4, combined_topk_weights + tidx * num_topk, num_max_rdma_chunked_recv_tokens, get_addr_fn, recv_tw_fn); if (lane_id < kNumRDMARanks) { rdma_receiver_rdma_head[target_rank][lane_id] = expected_head < 0 ? -expected_head - 1 : expected_head; } } } } // Retired syncwarp(); if (lane_id == 0) { rdma_receiver_retired[target_rank] = true; } } else if(warp_role == WarpRole::kNVLCoordinator) { // Coordinator // Sync shared memory status // sync_rdma_receiver_smem(); __syncthreads(); const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks; int last_rdma_head = 0; int last_nvl_head[kNumRDMARanks] = {0}; int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0; int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? 
lane_id : 0; while(true) { // Retired if(__all_sync(kFullWarpMask, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id])) break; // Find minimum head for RDMA ranks { int min_head = std::numeric_limits::max(); #pragma unroll for(int i = 0; i < kNumRDMAReceivers; ++i) if(not rdma_receiver_retired[i]) min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); if (min_head != std::numeric_limits::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) shmem_ctx_ulong_atomic_add(ctx, #else shmem_signal_op_add( #endif rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); last_rdma_head = min_head; } } // Nanosleep and let other warps work __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64); } } } #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) shmem_wg_ctx_destroy(&ctx); #endif } void combine(hipDataType type, void *combined_x, float *combined_topk_weights, const bool *is_combined_token_in_rank, const void *x, const float *topk_weights, const void *bias_0, const void *bias_1, const int *combined_rdma_head, const int *combined_nvl_head, const void *src_meta, const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum, const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens, int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode) { constexpr int kNumCombineForwarderWarps = 8; #define COMBINE_LAUNCH_CASE(num_rdma_ranks) \ { \ auto combine_func = \ low_latency_mode \ ? 
combine \ : combine; \ LAUNCH_KERNEL_NON_COOPERATIVE( \ &cfg, combine_func, reinterpret_cast(combined_x), combined_topk_weights, \ is_combined_token_in_rank, reinterpret_cast(x), topk_weights, \ reinterpret_cast(bias_0), reinterpret_cast(bias_1), \ combined_rdma_head, combined_nvl_head, reinterpret_cast(src_meta), \ rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \ num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr, \ num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs, \ num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks); \ } \ break int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1); int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder; EP_HOST_ASSERT(num_forwarder_warps >= NUM_MAX_NVL_PEERS); EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens)); EP_HOST_ASSERT(type == HIP_R_16BF); SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, (NUM_MAX_NVL_PEERS + 1) * kWarpSize, stream); SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE); #undef COMBINE_LAUNCH_CASE } } // namespace internode } // namespace deep_ep // #ifdef __clang__ // #pragma clang diagnostic pop // #endif // __clang__ #endif