Commit 1a35d640 authored by root
Browse files

Fix the 4-node core dump on dtk26.04.


Signed-off-by: root <root@host-10-212-17-3.cluster.local>
parent 95e46992
...@@ -47,7 +47,7 @@ struct Config { ...@@ -47,7 +47,7 @@ struct Config {
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % (2 * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL) == 0); EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % (2 * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL) == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_ranks <=8 ? num_sms / 2 : num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL; const int num_channels = num_sms / 2;
// 计算每个nvl通信数据包的数据量 // 计算每个nvl通信数据包的数据量
size_t num_single_nvl_bag_bytes = size_t num_single_nvl_bag_bytes =
...@@ -83,7 +83,7 @@ struct Config { ...@@ -83,7 +83,7 @@ struct Config {
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0); EP_HOST_ASSERT(num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_channels = num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL; const int num_channels = num_sms / 2;
// 计算每个rdma通信数据包的数据量 // 计算每个rdma通信数据包的数据量
size_t num_single_rdma_bag_bytes = size_t num_single_rdma_bag_bytes =
......
...@@ -809,8 +809,8 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te ...@@ -809,8 +809,8 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
// here. // here.
pybind11::gil_scoped_release release; pybind11::gil_scoped_release release;
const int num_channels = config.num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL; const int num_channels = config.num_sms / 2;
EP_HOST_ASSERT(config.num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0); // EP_HOST_ASSERT(config.num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS);
bool cached_mode = cached_rdma_channel_prefix_matrix.has_value(); bool cached_mode = cached_rdma_channel_prefix_matrix.has_value();
...@@ -1130,8 +1130,8 @@ Buffer::internode_combine( ...@@ -1130,8 +1130,8 @@ Buffer::internode_combine(
const torch::Tensor &combined_nvl_head, const Config &config, const torch::Tensor &combined_nvl_head, const Config &config,
std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream) { std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream) {
#ifndef DISABLE_ROCSHMEM #ifndef DISABLE_ROCSHMEM
const int num_channels = config.num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL; const int num_channels = config.num_sms / 2;
EP_HOST_ASSERT(config.num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0); // EP_HOST_ASSERT(config.num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
// Shape and contiguous checks // Shape and contiguous checks
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
......
...@@ -7,17 +7,6 @@ ...@@ -7,17 +7,6 @@
#ifndef DISABLE_ROCSHMEM #ifndef DISABLE_ROCSHMEM
// 安全检查:确保宏已定义
#ifndef HIP_VERSION_PATCH
#error "HIP_VERSION_PATCH not defined! Check your HIP installation."
#endif
// TODO: fix unroll warnings
// #ifdef __clang__
// #pragma clang diagnostic push
// #pragma clang diagnostic ignored "-Wpass-failed"
// #pragma clang diagnostic ignored "-Wdeprecated-volatile"
// #endif // __clang__
namespace deep_ep { namespace deep_ep {
namespace internode { namespace internode {
...@@ -25,7 +14,7 @@ namespace internode { ...@@ -25,7 +14,7 @@ namespace internode {
extern shmem_team_t cpu_rdma_team; extern shmem_team_t cpu_rdma_team;
struct SourceMeta { struct SourceMeta {
int src_rdma_rank, is_token_in_nvl_rank_bits; // sizeof(SourceMeta) = 8 int src_rdma_rank, is_token_in_nvl_rank_bits;
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers"); EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");
...@@ -60,18 +49,16 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_ ...@@ -60,18 +49,16 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
__host__ __device__ __forceinline__ std::pair<int, int> __host__ __device__ __forceinline__ std::pair<int, int>
get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_channels) { int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
// Return `int32_t` offset and count to clean // Return `int32_t` offset and count to clean
return {(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * return {(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) *
num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_channels) / sizeof(int), num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / sizeof(int),
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels}; (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms};
} }
__host__ __device__ __forceinline__ std::pair<int, int> __host__ __device__ __forceinline__ std::pair<int, int>
get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens, int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens,
int num_channels) { int num_sms) {
// Return `int32_t` offset and to clean // Return `int32_t` offset and to clean
EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0,
"Invalid size of `SourceMeta`"); "Invalid size of `SourceMeta`");
...@@ -79,8 +66,8 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -79,8 +66,8 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
(num_nvl_recv_buffer_tokens * (num_nvl_recv_buffer_tokens *
(hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + (hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
num_topk_weights * sizeof(float) + sizeof(SourceMeta)) * num_topk_weights * sizeof(float) + sizeof(SourceMeta)) *
num_nvl_ranks * num_channels) / sizeof(int), num_nvl_ranks * num_sms) / sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels, num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms,
}; };
} }
...@@ -92,9 +79,10 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank, ...@@ -92,9 +79,10 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
template <bool kLowLatencyMode> template <bool kLowLatencyMode>
__forceinline__ __device__ void __forceinline__ __device__ void
dushmem_barrier_with_same_gpu_idx(const shmem_team_t &rdma_team) { shmem_sync_with_same_gpu_idx(const shmem_team_t &rdma_team) {
// NOTE: shmem_device_barrier_all() might be an issue as // NOTE: shmem_device_barrier_all() might be an issue as
// it doesn't follow OpenSHMEM specification on ROCm // it doesn't follow OpenSHMEM specification on ROCm
// kLowLatencyMode ? shmem_device_sync(rdma_team) : shmem_device_sync_all();
kLowLatencyMode ? shmem_barrier(rdma_team) : shmem_device_barrier_all(); kLowLatencyMode ? shmem_barrier(rdma_team) : shmem_device_barrier_all();
} }
...@@ -123,7 +111,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in ...@@ -123,7 +111,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
// Communication with others // Communication with others
// Global barrier: the first warp do intra-node sync, the second warp do internode sync // Global barrier: the first warp do intra-node sync, the second warp do internode sync
if (thread_id == kWarpSize) if (thread_id == kWarpSize)
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team); shmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank); barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
...@@ -175,7 +163,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in ...@@ -175,7 +163,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
__syncthreads(); __syncthreads();
if (thread_id == 0) if (thread_id == 0)
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team); shmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
__syncthreads(); __syncthreads();
...@@ -266,7 +254,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in ...@@ -266,7 +254,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
// Finally barrier // Finally barrier
if (thread_id == kWarpSize) if (thread_id == kWarpSize)
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team); shmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank); barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
} else { } else {
...@@ -383,759 +371,754 @@ constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { ...@@ -383,759 +371,754 @@ constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
return num_rdma_ranks < 8 ? num_rdma_ranks : 8; return num_rdma_ranks < 8 ? num_rdma_ranks : 8;
} }
template <bool kLowLatencyMode, int kNumRDMARanks, bool kCachedMode,
template <bool kLowLatencyMode,
int kNumRDMARanks,
bool kCachedMode,
int kNumDispatchRDMASenderWarps, int kNumDispatchRDMASenderWarps,
int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)> int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>
__global__ void __launch_bounds__(((1 + NUM_MAX_NVL_PEERS) * kWarpSize), 1) __global__ void
dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights, __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * kWarpSize), 1)
SourceMeta *recv_src_meta, const int4 *x, const float *x_scales, dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
const int64_t *topk_idx, const float *topk_weights, int *send_rdma_head, SourceMeta *recv_src_meta, const int4 *x, const float *x_scales,
int *send_nvl_head, int *recv_rdma_channel_prefix_matrix, const int64_t *topk_idx, const float *topk_weights, int *send_rdma_head,
int *recv_gbl_channel_prefix_matrix, const int *rdma_channel_prefix_matrix, int *send_nvl_head, int *recv_rdma_channel_prefix_matrix,
const int *recv_rdma_rank_prefix_sum, const int *gbl_channel_prefix_matrix, int *recv_gbl_channel_prefix_matrix, const int *rdma_channel_prefix_matrix,
const int *recv_gbl_rank_prefix_sum, const bool *is_token_in_rank, int num_tokens, const int *recv_rdma_rank_prefix_sum, const int *gbl_channel_prefix_matrix,
int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride, const int *recv_gbl_rank_prefix_sum, const bool *is_token_in_rank, int num_tokens,
int scale_hidden_stride, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs, int scale_hidden_stride, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_ranks) { int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks) {
enum class WarpRole { enum class WarpRole {
kRDMASender, // 从x写入到RDMA发送缓存 kRDMASender,
kRDMASenderCoordinator, // 从RDMA发送缓存写入到远端rdma_rank接收缓存 kRDMASenderCoordinator,
kRDMAAndNVLForwarder, // 从RDMA接收缓存转写到ipc nvl缓存 kRDMAAndNVLForwarder,
kForwarderCoordinator, // 向远端RDMA确认接收 kForwarderCoordinator,
kNVLReceivers // 从nvl缓存写入到recv_x kNVLReceivers
}; };
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
__shared__ shmem_ctx_t ctx; __shared__ rocshmem::rocshmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx); rocshmem::rocshmem_wg_ctx_create(0, &ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize; const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
const auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id(); const auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize,
const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, lane_id = get_lane_id();
channel_id = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL; const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2;
const bool is_forwarder = sm_id % 2 == 0;
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
EP_DEVICE_ASSERT(num_warps == 1 + NUM_MAX_NVL_PEERS);
const auto role_meta = [=]() -> std::pair<WarpRole, int> { const auto role_meta = [=]() -> std::pair<WarpRole, int> {
if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) { if (is_forwarder) {
if(warp_id < kNumDispatchRDMASenderWarps) { if (warp_id < NUM_MAX_NVL_PEERS) {
return {WarpRole::kRDMASender, -1};
} else if(warp_id == kNumDispatchRDMASenderWarps) {
return {WarpRole::kRDMASenderCoordinator, -1};
}
} else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
if(warp_id < NUM_MAX_NVL_PEERS) {
return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
} else { } else {
return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS}; return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};
} }
} else if (warp_id < kNumDispatchRDMASenderWarps) {
return {WarpRole::kRDMASender, -1};
} else if (warp_id == kNumDispatchRDMASenderWarps) {
return {WarpRole::kRDMASenderCoordinator, -1};
} else { } else {
return {WarpRole::kNVLReceivers, (warp_id + channel_id + 1) % NUM_MAX_NVL_PEERS}; return {WarpRole::kNVLReceivers,
(warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS};
} }
}(); }();
auto warp_role = role_meta.first; auto warp_role = role_meta.first;
auto target_rank = role_meta.second; // Not applicable for RDMA senders auto target_rank = role_meta.second; // Not applicable for RDMA senders
EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS);
// Data checks
EP_DEVICE_ASSERT(num_topk <= kWarpSize);
// RDMA symmetric layout // RDMA symmetric layout
auto hidden_bytes = hidden_int4 * sizeof(int4); EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t),
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk); "Invalid number of NVL peers");
auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels); auto hidden_bytes = hidden_int4 * sizeof(int4);
auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels); auto num_bytes_per_rdma_token =
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_data = SymBuffer<int8_t>(
rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks,
channel_id, num_channels);
auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2,
kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head =
SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail =
SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
// NVL buffer layouts // NVL buffer layouts
// NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers" // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr`
// means "Write for Senders, Read for Receivers"
void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr; void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
int rs_wr_rank = 0, ws_rr_rank = 0; int rs_wr_rank = 0, ws_rr_rank = 0;
if (warp_role == WarpRole::kRDMAAndNVLForwarder) if (warp_role == WarpRole::kRDMAAndNVLForwarder)
rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank; rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank],
rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
if (warp_role == WarpRole::kNVLReceivers) if (warp_role == WarpRole::kNVLReceivers)
rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank; rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank],
rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
// Allocate buffers // Allocate buffers
auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_x =
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4,
auto nvl_channel_x_scales = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank)
auto nvl_channel_topk_idx = AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); .advance_also(rs_wr_buffer_ptr);
auto nvl_channel_topk_weights = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_src_meta =
auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS,
auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); channel_id, num_channels, rs_wr_rank)
auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr); .advance_also(rs_wr_buffer_ptr);
auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_x_scales =
AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales,
NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank)
.advance_also(rs_wr_buffer_ptr);
auto nvl_channel_topk_idx =
AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk,
NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank)
.advance_also(rs_wr_buffer_ptr);
auto nvl_channel_topk_weights =
AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk,
NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank)
.advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_start =
AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id,
num_channels, rs_wr_rank)
.advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_end =
AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id,
num_channels, rs_wr_rank)
.advance_also(rs_wr_buffer_ptr);
auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id,
num_channels, ws_rr_rank)
.advance_also(ws_rr_buffer_ptr);
auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id,
num_channels, rs_wr_rank)
.advance_also(rs_wr_buffer_ptr);
// RDMA sender warp synchronization // RDMA sender warp synchronization
__shared__ volatile int rdma_send_next_token_idx; __shared__ volatile int rdma_send_next_token_idx;
__shared__ volatile int rdma_send_channel_tail[kNumRDMARanks]; __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
__shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks]; __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks];
__shared__ volatile int rdma_sender_counter[1];
__shared__ volatile int rdma_forwarder_counter[1];
if (threadIdx.x == 0) {
rdma_sender_counter[0] = 0;
rdma_forwarder_counter[0] = 0;
}
__syncthreads();
// NVL and RDMA coordinate Forward warp synchronization auto sync_rdma_sender_smem = [&]() {
__shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks]; if (lane_id == 0) {
__shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS]; volatile int ret = __hip_atomic_fetch_add(&rdma_sender_counter[0], 1, __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_WORKGROUP);
// Place the main logic of your kernel here, using the parameters above. // volatile int ret = atomicAdd((int*)&rdma_sender_counter[0], 1);
if(warp_role == WarpRole::kRDMASender) { }
/* syncwarp();
这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。 while (rdma_sender_counter[0] < (kNumDispatchRDMASenderWarps + 1)) {
它首先获取当前通道的任务范围,然后清理共享内存,接着计算并发送本通道中的令牌数量。 }
然后,它遍历所有的令牌,读取每个令牌的RDMA秩的存在性,获取顺序锁,计算下一个尾部位置,存储RDMA头部,更新最后一个令牌尾部,释放顺序锁,并广播尾部位置。 };
最后,它复制相关的数据到对称发送缓冲区。
kRDMASender主要目的是将发送信息x, x_scale,source_meta, topk_idx, topk_weight等信息填充进入rdma发送缓存,
期间要同步warp直接对token的依序操作,以及和kForwarderCoordinator, kRDMASenderCoordinator内存同步。
同时在复制操作时, 使用ld.global.nc.L1::no_allocate.L2::256B, st.global.L1::no_allocate减少L1/L2缓存使用。
*/
// 获取任务范围
int token_start_idx, token_end_idx;
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// 清理共享内存 // Forward warp synchronization
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA秩数量"); __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
if(warp_id == 0 && lane_id == 0) { __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
rdma_send_next_token_idx = token_start_idx; // NOTE: Not sure that __syncthreads() is a suitable replacement
auto sync_forwarder_smem = [&]() {
if (lane_id == 0) {
volatile int ret = __hip_atomic_fetch_add(
&rdma_forwarder_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
// volatile int ret = atomicAdd((int*)&rdma_forwarder_counter[0], 1);
} }
if(warp_id == 0 && lane_id < kNumRDMARanks) { syncwarp();
rdma_send_channel_tail[lane_id] = 0; while (rdma_forwarder_counter[0] < (NUM_MAX_NVL_PEERS + 1)) {
rdma_send_channel_next_tail[lane_id] = 0;
} }
};
// 发送本通道中的令牌数量,通过 `-value - 1` 表示 if (warp_role == WarpRole::kRDMASender) {
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= kWarpSize, "无效的NVL对等体数量"); // Get tasks
// 对于每个目标RDMA秩,以warp为单位进行迭代。计算发送缓冲区的值,并存储在rdma_channel_meta.send_buffer中 int token_start_idx, token_end_idx;
// 用于填充rdma_channel_meta.send_buffer本节点发送到远端rank, rdma_rank的起始index和结束index get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx,
for(int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) { token_end_idx);
auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
// Clean shared memory
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA ranks");
(warp_id == 0 and lane_id == 0) ? (rdma_send_next_token_idx = token_start_idx) : 0;
(warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
(warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0;
// Send number of tokens in this channel by `-value - 1`
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= kWarpSize,
"Invalid number of NVL peers");
for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks;
dst_rdma_rank += kNumDispatchRDMASenderWarps) {
if (lane_id < NUM_MAX_NVL_PEERS) { if (lane_id < NUM_MAX_NVL_PEERS) {
dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1; rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] =
-(channel_id == 0
? 0
: gbl_channel_prefix_matrix
[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels +
channel_id - 1]) -
1;
} else if (lane_id < NUM_MAX_NVL_PEERS * 2) { } else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1; rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] =
-gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id -
NUM_MAX_NVL_PEERS) *
num_channels +
channel_id] -
1;
} else if (lane_id == NUM_MAX_NVL_PEERS * 2) { } else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1; rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] =
-(channel_id == 0 ? 0
: rdma_channel_prefix_matrix[dst_rdma_rank * num_channels +
channel_id - 1]) -
1;
} else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) { } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1; rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] =
-rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
} }
syncwarp(); rocshmem::rocshmem_ctx_int_put_nbi_wave(
if (dst_rdma_rank != rdma_rank) { ctx, rdma_channel_meta.recv_buffer(rdma_rank),
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_int_put_nbi_warp(ctx,
#else
shmemx_int_put_nbi_warp(
#endif
rdma_channel_meta.recv_buffer(rdma_rank),
rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2, rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
} }
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) rocshmem::rocshmem_ctx_quiet(ctx);
shmem_ctx_quiet(ctx); sync_rdma_sender_smem();
#else
shmem_fence();
#endif
// sync_rdma_sender_smem(); // Iterate over tokens and copy into buffer
__syncthreads();
// 遍历令牌并复制到缓冲区
int64_t token_idx; int64_t token_idx;
int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1; int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1;
auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id); auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id)
for(token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) { : rdma_channel_data.send_buffer(lane_id);
// 读取RDMA秩的存在性 for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx;
token_idx += kNumDispatchRDMASenderWarps) {
// Read RDMA rank existence
uint64_t is_token_in_rank_uint64 = 0; uint64_t is_token_in_rank_uint64 = 0;
if(lane_id < kNumRDMARanks) { if (lane_id < kNumRDMARanks)
is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS); is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t *>(
} is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
// 获得处理数据的自旋锁,获得锁后才会处理一些数据信息 // Acquire sequential lock
while(lane_id == 0 && rdma_send_next_token_idx != token_idx) { while (lane_id == 0 and rdma_send_next_token_idx != token_idx)
// 等待 ;
}
syncwarp(); syncwarp();
// 获取下一个尾部位置 // Acquire next tail
int rdma_tail_idx = -1; int rdma_tail_idx = -1;
if(is_token_in_rank_uint64 != 0) { if (is_token_in_rank_uint64 != 0) {
rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++; rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++;
while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens)
// 与kForwarderCoordinator相互配合,调节发送数据的频率 cached_rdma_channel_head =
while(rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) { static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
}
} }
syncwarp(); syncwarp();
// 存储RDMA头部以供合并 // Store RDMA head for combine
if(lane_id < kNumRDMARanks && !kCachedMode) { if (lane_id < kNumRDMARanks and not kCachedMode)
send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx; send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;
}
// 更新最后一个令牌尾部 // Update last token tail
if(last_rdma_tail_idx >= 0) { if (last_rdma_tail_idx >= 0)
st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1); st_release_cta(const_cast<const int *>(rdma_send_channel_tail + lane_id),
} last_rdma_tail_idx + 1);
last_rdma_tail_idx = rdma_tail_idx; last_rdma_tail_idx = rdma_tail_idx;
// 释放顺序锁 // Release sequential lock
if(lane_id == 0) { lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
rdma_send_next_token_idx += 1;
}
// 广播尾部位置 // Broadcast tails
SourceMeta src_meta; SourceMeta src_meta;
int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks]; int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
void* dst_send_buffers[kNumTopkRDMARanks]; void *dst_send_buffers[kNumTopkRDMARanks];
/* #pragma unroll
该for循环主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作 for (int i = 0, slot_idx; i < kNumRDMARanks; ++i)
*/ if ((slot_idx = shfl_sync(rdma_tail_idx, i)) >= 0) {
#pragma unroll slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;
for(int i = 0, slot_idx; i < kNumRDMARanks; ++i) { topk_ranks[num_topk_ranks] = i;
// 使用__shfl_sync函数在warp内同步并广播rdma_tail_idx的值
if((slot_idx = shfl_sync(rdma_tail_idx, i)) >= 0) {
// warp 所有线程参与,rdma_tail_idx默认为-1, 只有对应rdma rank需要发送时, rdma_tail_idx才会>=0
// 计算slot_idx在接收缓冲区中的位置
slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;
// 存储当前RDMA秩到topk_ranks数组中
topk_ranks[num_topk_ranks] = i;
// 广播is_token_in_rank_uint64的值到所有线程,并解释为布尔数组
auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i); auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);
auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64); auto recv_is_token_in_rank_values =
reinterpret_cast<const bool *>(&recv_is_token_in_rank_uint64);
// 如果当前lane_id等于num_topk_ranks,则更新src_meta if (lane_id == num_topk_ranks)
if(lane_id == num_topk_ranks) {
src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values); src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
} dst_send_buffers[num_topk_ranks++] =
reinterpret_cast<uint8_t *>(broadcast(send_buffer, i)) +
// 计算目标发送缓冲区的地址,并存储在dst_send_buffers数组中 slot_idx * num_bytes_per_rdma_token;
// 获取到发送地址, num_topk_ranks-1 是需要发送的ranks数
dst_send_buffers[num_topk_ranks++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token;
} }
}
EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks); EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);
//////////////// 复制数据到发送缓冲区 //////////////// // Copy `x` into symmetric send buffer
// 复制源元数据到对称发送缓冲区 auto st_broadcast = [=](const int key, const int4 &value) {
if(lane_id < num_topk_ranks) { for (int j = 0; j < num_topk_ranks; ++j)
st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta); st_na_global(reinterpret_cast<int4 *>(dst_send_buffers[j]) + key, value);
} };
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4,
ld_nc_global, st_broadcast);
for (int i = 0; i < num_topk_ranks; ++i)
dst_send_buffers[i] = reinterpret_cast<int4 *>(dst_send_buffers[i]) + hidden_int4;
for(int i = 0; i < num_topk_ranks; ++i) { // Copy source metadata into symmetric send buffer
dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1; if (lane_id < num_topk_ranks)
} st_na_global(reinterpret_cast<SourceMeta *>(dst_send_buffers[lane_id]), src_meta);
// 复制 `x` 到对称发送缓冲区 for (int i = 0; i < num_topk_ranks; ++i)
auto st_broadcast = [=](const int key, const int4& value) { dst_send_buffers[i] = reinterpret_cast<SourceMeta *>(dst_send_buffers[i]) + 1;
for(int j = 0; j < num_topk_ranks; ++j) {
st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);
}
};
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);
for(int i = 0; i < num_topk_ranks; ++i) {
dst_send_buffers[i] = reinterpret_cast<int4*>(dst_send_buffers[i]) + hidden_int4;
}
// 复制 `x_scales` 到对称发送缓冲区 // Copy `x_scales` into symmetric send buffer
for(int i = lane_id; i < num_scales; i += kWarpSize) { for (int i = lane_id; i < num_scales; i += kWarpSize) {
auto value = ld_nc_global(x_scales + token_idx * num_scales + i); auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
for(int j = 0; j < num_topk_ranks; ++j) { for (int j = 0; j < num_topk_ranks; ++j)
st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value); st_na_global(reinterpret_cast<float *>(dst_send_buffers[j]) + i, value);
}
}
for(int i = 0; i < num_topk_ranks; ++i) {
dst_send_buffers[i] = reinterpret_cast<float*>(dst_send_buffers[i]) + num_scales;
} }
for (int i = 0; i < num_topk_ranks; ++i)
dst_send_buffers[i] = reinterpret_cast<float *>(dst_send_buffers[i]) + num_scales;
// 复制 `topk_idx` `topk_weights` 到对称发送缓冲区 // Copy `topk_idx` and `topk_weights` into symmetric send buffer
for(int i = lane_id; i < num_topk * num_topk_ranks; i += kWarpSize) { for (int i = lane_id; i < num_topk * num_topk_ranks; i += kWarpSize) {
auto rank_idx = i / num_topk, copy_idx = i % num_topk; auto rank_idx = i / num_topk, copy_idx = i % num_topk;
auto idx_value = static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx)); auto idx_value =
static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx); auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx);
st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value); st_na_global(reinterpret_cast<int *>(dst_send_buffers[rank_idx]) + copy_idx,
st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value); idx_value);
st_na_global(reinterpret_cast<float *>(dst_send_buffers[rank_idx]) + num_topk +
copy_idx,
weight_value);
} }
} }
// 结尾部分 // Epilogue
// 获取顺序锁 // Acquire sequential lock
while(lane_id == 0 && rdma_send_next_token_idx != token_idx) { while (lane_id == 0 and rdma_send_next_token_idx != token_idx)
// 等待 ;
}
syncwarp(); syncwarp();
// 更新最后一个令牌尾部 // Update last token tail
if(last_rdma_tail_idx >= 0) { if (last_rdma_tail_idx >= 0)
st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1); st_release_cta(const_cast<const int *>(rdma_send_channel_tail + lane_id),
} last_rdma_tail_idx + 1);
// 释放顺序锁 // Release sequential lock
if(lane_id == 0) { lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
rdma_send_next_token_idx += 1; } else if (warp_role == WarpRole::kRDMASenderCoordinator) {
} // NOTES: in case of splitting the issued put at the end of the buffer
} else if(warp_role == WarpRole::kRDMASenderCoordinator) { EP_DEVICE_ASSERT(
/* num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
它首先计算每个RDMA秩需要发送的令牌数,然后在所有RDMA秩之间循环,检查是否有令牌需要发送。
如果有,它将计算本次需要发出的令牌数,并发出相应的RDMA发送请求。
最后,它更新相关的尾部位置,以便下次循环时可以正确地计算需要发送的令牌数。
kRDMASenderCoordinator使用了同sm内存一致性(ld.acquire.cta.s32),
dushmem内存一致性(dushmem_fence)和原子操作(dushmemx_signal_op),减少硬同步,提升整体效率。
*/
if(warp_id > kNumDispatchRDMASenderWarps) {
return;
}
// 确保最大接收令牌数可以被最大发送令牌数整除,以避免缓冲区分割问题
EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
// 同步共享内存,确保所有线程在继续之前都达到了这一点 // Synchronize shared memory
// sync_rdma_sender_smem(); sync_rdma_sender_smem();
__syncthreads();
// 计算当前通道需要发送的令牌数 // Get number of tokens to send for each RDMA rank
int num_tokens_to_send = 0; int num_tokens_to_send = 0;
if(lane_id < kNumRDMARanks) { if (lane_id < kNumRDMARanks) {
num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id]; num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];
if(channel_id > 0) if (channel_id > 0)
num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1]; num_tokens_to_send -=
rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];
} }
// 记录上次发出的尾部位置 // Iterate all RDMA ranks
int last_issued_tail = 0; int last_issued_tail = 0;
// 当有任何RDMA秩需要发送令牌时,继续循环 while (__any_sync(kFullWarpMask, num_tokens_to_send > 0)) {
while(__any_sync(kFullWarpMask, num_tokens_to_send > 0)) { for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) {
for(int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) { int dst_rdma_rank = (i + channel_id) % kNumRDMARanks;
// 计算目标RDMA秩
int dst_rdma_rank = (i + channel_id) % kNumRDMARanks;
// 获取同步后的需要发送的令牌数
synced_num_tokens_to_send = shfl_sync(num_tokens_to_send, dst_rdma_rank); synced_num_tokens_to_send = shfl_sync(num_tokens_to_send, dst_rdma_rank);
if (synced_num_tokens_to_send == 0)
continue;
if(synced_num_tokens_to_send == 0) // Read progress
continue; // 如果没有令牌需要发送,则跳过
// 读取进度
auto synced_last_issued_tail = shfl_sync(last_issued_tail, dst_rdma_rank); auto synced_last_issued_tail = shfl_sync(last_issued_tail, dst_rdma_rank);
auto processed_tail = ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank)); auto processed_tail =
auto num_tokens_processed = processed_tail - synced_last_issued_tail; ld_acquire_cta(const_cast<const int *>(rdma_send_channel_tail + dst_rdma_rank));
auto num_tokens_processed = processed_tail - synced_last_issued_tail;
// 如果处理的令牌数不等于需要发送的令牌数,并且处理的令牌数小于最大发送令牌数,则跳过 if (num_tokens_processed != synced_num_tokens_to_send and
if(num_tokens_processed != synced_num_tokens_to_send && num_tokens_processed < num_max_rdma_chunked_send_tokens) num_tokens_processed < num_max_rdma_chunked_send_tokens)
continue; continue;
// 计算本次需要发出的令牌数 // Issue RDMA send
auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens); auto num_tokens_to_issue =
EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 && num_tokens_to_issue <= synced_num_tokens_to_send); min(num_tokens_processed, num_max_rdma_chunked_send_tokens);
EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 and
// 发出RDMA发送请求 num_tokens_to_issue <= synced_num_tokens_to_send);
if(dst_rdma_rank != rdma_rank) { if (dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens; auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <=
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) num_max_rdma_chunked_recv_tokens);
shmem_ctx_schar_put_nbi_warp(ctx, rocshmem::rocshmem_ctx_schar_put_nbi_wave(
#else ctx,
shmemx_int8_put_nbi_warp(
#endif
rdma_channel_data.recv_buffer(rdma_rank) + rdma_channel_data.recv_buffer(rdma_rank) +
dst_slot_idx * num_bytes_per_rdma_token, dst_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_channel_data.send_buffer(dst_rdma_rank) +
dst_slot_idx * num_bytes_per_rdma_token, dst_slot_idx * num_bytes_per_rdma_token,
num_bytes_per_rdma_token * num_tokens_to_issue, num_bytes_per_rdma_token * num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) rocshmem::rocshmem_ctx_quiet(ctx);
shmem_ctx_quiet(ctx);
#else
shmem_fence();
#endif
} else { } else {
// 对于本地RDMA秩,使用较轻的内存屏障 // Lighter fence for local RDMA rank
memory_fence(); memory_fence();
} }
// 更新尾部位置 // Update tails
syncwarp(); syncwarp();
if(lane_id == dst_rdma_rank) { if (lane_id == dst_rdma_rank) {
last_issued_tail += num_tokens_to_issue; last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue;
// 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) rocshmem::rocshmem_ctx_ulong_atomic_add(
shmem_ctx_ulong_atomic_add(ctx, ctx, rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
#else
shmem_signal_op_add(
#endif
rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
} }
} }
} // while(__any(num_tokens_to_send > 0)) }
} else if(warp_role == WarpRole::kRDMAAndNVLForwarder) { } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
/* // RDMA consumers and NVL producers
这段代码的主要功能是在一个CUDA内核中协调从RDMA消费者到NVL生产者的转发操作。 const auto dst_nvl_rank = target_rank;
它首先计算目标NVL秩和目标秩,然后等待相关的计数器到达。 const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank;
接着,它检查目标队列是否为空,或者等待一个缓冲区被释放。 const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks);
然后,它找到下一个源RDMA秩,并遍历RDMA缓冲区中的每一个令牌,复制相关的数据到NVL缓冲区。 const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks);
最后,它同步头部和尾部索引,并标记通道为退役状态。
*/ // Wait counters to arrive
// RDMA消费者和NVL生产者
const auto dst_nvl_rank = target_rank; // 目标NVL秩
const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank; // 目标秩
const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks); // 目标秩专家开始
const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks); // 目标秩专家结束
// 等待计数器到达
int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0; int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize); EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
auto start_time = wall_clock64(); auto start_time = wall_clock64();
if(lane_id < kNumRDMARanks) { if (lane_id < kNumRDMARanks) {
while(true) { while (true) {
// 对应于kRDMASender中的数据写入 auto meta_0 =
auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank); // 是nvl节点的起始地址 ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank);
auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); // nvl节点的结束地址 auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) +
auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2); // 本rdma节点的起始地址 NUM_MAX_NVL_PEERS + dst_nvl_rank);
auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1); // 本节点的结束地址 auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) +
if(meta_0 < 0 && meta_1 < 0 && meta_2 < 0 && meta_3 < 0) { NUM_MAX_NVL_PEERS * 2);
// 通知NVL秩 auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) +
NUM_MAX_NVL_PEERS * 2 + 1);
if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) {
// Notify NVL ranks
int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1; int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
EP_DEVICE_ASSERT(start_sum >= 0 && end_sum >= 0 && end_sum >= start_sum); EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and
end_sum >= start_sum);
st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1); st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id,
-start_sum - 1);
st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1); st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);
// 保存从RDMA通道接收的令牌计数 // Save RDMA channel received token count
src_rdma_channel_prefix = -meta_2 - 1; src_rdma_channel_prefix = -meta_2 - 1;
auto src_rdma_channel_prefix_1 = -meta_3 - 1; auto src_rdma_channel_prefix_1 = -meta_3 - 1;
num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; // 是远端 rdma_rank 会发送给当前节点的token数量 num_tokens_to_recv_from_rdma =
if(!kCachedMode) src_rdma_channel_prefix_1 - src_rdma_channel_prefix;
recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1; if (not kCachedMode)
recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] =
src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1]; // 对应的远端 rdma_rank 的起始index, 存在线程0之中 src_rdma_channel_prefix_1;
src_rdma_channel_prefix +=
lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1];
EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0); EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
break; break;
} }
// 超时检查 // Timeout check
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { long long int elapsed_time =
printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n", wall_clock64() > start_time ? wall_clock64() - start_time : 0;
channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3); if (elapsed_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, "
"nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1,
meta_2, meta_3);
trap(); trap();
} }
} }
} }
syncwarp(); syncwarp();
// Shift cached head
// 移动缓存的头部
send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank; send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank;
// 等待共享内存被清理 // Wait shared memory to be cleaned
// sync_forwarder_smem(); sync_forwarder_smem();
__syncthreads();
// 开始准备处理接受数据,直到所有的数据接受完成。 // Forward tokens from RDMA buffer
// 转发从RDMA缓冲区的令牌 // NOTES: always start from the local rank
// 注意:总是从本地秩开始 int src_rdma_rank = sm_id % kNumRDMARanks;
int src_rdma_rank = sm_id % kNumRDMARanks;
int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0; int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0; int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0;
while(__any_sync(kFullWarpMask, num_tokens_to_recv_from_rdma > 0)) { while (__any_sync(kFullWarpMask, num_tokens_to_recv_from_rdma > 0)) {
// 检查nvl目标队列是否为空,或者等待一个缓冲区被释放 // Check destination queue emptiness, or wait a buffer to be released
start_time = wall_clock64(); start_time = wall_clock64();
while (lane_id == 0) {
// 用于给kNVLReceivers进行互动,控制数据的传输速度
while(lane_id == 0) {
int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head; int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
if(num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens) if (num_max_nvl_chunked_recv_tokens - num_used_slots >=
num_max_nvl_chunked_send_tokens)
break; break;
cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer()); cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
// 超时检查 // Timeout check
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { long long int elapsed_time =
printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n", wall_clock64() > start_time ? wall_clock64() - start_time : 0;
channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail); if (elapsed_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, "
"nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, dst_nvl_rank,
ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
trap(); trap();
} }
} }
syncwarp(); syncwarp();
// 找到下一个源RDMA秩(轮询) // Find next source RDMA rank (round-robin)
start_time = wall_clock64(); start_time = wall_clock64();
while(true) { while (true) {
src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks; src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
if(shfl_sync(num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) { if (shfl_sync(num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
if(lane_id == src_rdma_rank && cached_rdma_channel_head == cached_rdma_channel_tail) if (lane_id == src_rdma_rank and
cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank))); cached_rdma_channel_head == cached_rdma_channel_tail)
cached_rdma_channel_tail = static_cast<int>(
if(shfl_sync(cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) { ld_relaxed_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
if (shfl_sync(cached_rdma_channel_tail > cached_rdma_channel_head,
src_rdma_rank))
break; break;
}
} }
// 超时检查 // Timeout check
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { long long int elapsed_time =
printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n", wall_clock64() > start_time ? wall_clock64() - start_time : 0;
channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma); if (elapsed_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, "
"nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: "
"%d\n",
channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id,
cached_rdma_channel_head, cached_rdma_channel_tail,
num_tokens_to_recv_from_rdma);
trap(); trap();
} }
} }
auto src_rdma_head = shfl_sync(cached_rdma_channel_head, src_rdma_rank); auto src_rdma_head = shfl_sync(cached_rdma_channel_head, src_rdma_rank);
auto src_rdma_tail = shfl_sync(cached_rdma_channel_tail, src_rdma_rank); auto src_rdma_tail = shfl_sync(cached_rdma_channel_tail, src_rdma_rank);
// 遍历RDMA缓冲区中的每一个令牌 // Iterate over every token from the RDMA buffer
for(int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) { for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) {
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens; auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
// 首先读取SourceMeta,对应到kRDMASenderCoordinator中 kRDMASender 的数据远程写入 void *shifted = rdma_channel_data.recv_buffer(src_rdma_rank) +
void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token; rdma_slot_idx * num_bytes_per_rdma_token;
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted))); auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta *>(
if(lane_id == src_rdma_rank) { reinterpret_cast<int8_t *>(shifted) + hidden_bytes));
num_tokens_to_recv_from_rdma -= 1; lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;
}
bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank); bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
if(lane_id == src_rdma_rank) { if (lane_id == src_rdma_rank) {
auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1; auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1;
rdma_nvl_token_idx += is_in_dst_nvl_rank; rdma_nvl_token_idx += is_in_dst_nvl_rank;
if(!kCachedMode) if (not kCachedMode)
send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head; send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;
} }
if (not is_in_dst_nvl_rank)
if(!is_in_dst_nvl_rank)
continue; continue;
// 获取一个空闲槽位 // Get an empty slot
int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens; int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens;
// 设置 src和dst 位置 // Copy data
auto src_gpu_buffer_x = reinterpret_cast<int4*>(reinterpret_cast<int8_t*>(shifted) + sizeof(SourceMeta));
auto src_gpu_buffer_scales = reinterpret_cast<float*>(reinterpret_cast<int8_t*>(src_gpu_buffer_x) + hidden_bytes);
auto src_gpu_buffer_topk_idx = reinterpret_cast<int*>(reinterpret_cast<int8_t*>(src_gpu_buffer_scales) + num_scales * sizeof(float));
auto src_gpu_buffer_topk_weights = reinterpret_cast<float*>(reinterpret_cast<int8_t*>(src_gpu_buffer_topk_idx) + num_topk * sizeof(int));
auto dst_gpu_buffer_x = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
auto dst_gpu_buffer_scales = nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales;
auto dst_gpu_buffer_topk_idx = nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk;
auto dst_gpu_buffer_topk_weights = nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk;
if(lane_id == 0) {
st_na_global(reinterpret_cast<int64_t*>(nvl_channel_src_meta.buffer() + dst_slot_idx),
*reinterpret_cast<int64_t*>(&src_meta));
}
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
dst_gpu_buffer_x, nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
src_gpu_buffer_x, reinterpret_cast<int4 *>(shifted), ld_nc_global, st_na_global);
ld_direct_global, st_na_global); shifted = reinterpret_cast<int4 *>(shifted) + hidden_int4;
// Copy source meta
if (lane_id == 0)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
shifted = reinterpret_cast<SourceMeta *>(shifted) + 1;
// Copy `x_scales`
UNROLLED_WARP_COPY(1, lane_id, num_scales, UNROLLED_WARP_COPY(1, lane_id, num_scales,
dst_gpu_buffer_scales, nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
src_gpu_buffer_scales, reinterpret_cast<float *>(shifted), ld_nc_global, st_na_global);
ld_direct_global, st_na_global); shifted = reinterpret_cast<float *>(shifted) + num_scales;
for(int t = lane_id; t < num_topk; t += kWarpSize) { // Copy `topk_idx` and `topk_weights`
int idx_val = ld_direct_global(reinterpret_cast<int*>(src_gpu_buffer_topk_idx) + t); // NOTES: do not use `shifted` after this `if`, because only several lanes are
float w_val = ld_direct_global(reinterpret_cast<float*>(src_gpu_buffer_topk_weights) + t); // shifted
int new_idx = (idx_val >= dst_rank_expert_begin && idx_val < dst_rank_expert_end) if (lane_id < num_topk) {
? (idx_val - dst_rank_expert_begin) : -1; // Read
float new_w = (new_idx != -1) ? w_val : 0.0f; auto idx_value = ld_nc_global(reinterpret_cast<int *>(shifted) + lane_id);
dst_gpu_buffer_topk_idx[t] = new_idx; shifted = reinterpret_cast<int *>(shifted) + num_topk;
dst_gpu_buffer_topk_weights[t] = new_w; auto weight_value = ld_nc_global(reinterpret_cast<float *>(shifted) + lane_id);
// Transform and write
idx_value =
(idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end)
? idx_value - dst_rank_expert_begin
: -1;
st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id,
idx_value);
weight_value = idx_value >= 0 ? weight_value : 0.0f;
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk +
lane_id,
weight_value);
} }
// 在NVL缓冲区不足的情况下,提前停止 // In case of insufficient NVL buffers, early stopping
if((++num_tokens_sent) == num_max_nvl_chunked_send_tokens) if ((++num_tokens_sent) == num_max_nvl_chunked_send_tokens)
src_rdma_tail = i + 1; src_rdma_tail = i + 1;
} }
// 同步头部索引 // Sync head index
if(lane_id == src_rdma_rank) if (lane_id == src_rdma_rank)
forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail); forward_channel_head[dst_nvl_rank][src_rdma_rank] =
(cached_rdma_channel_head = src_rdma_tail);
// 移动尾部索引,与kNVLReceivers互相通信使用 // Move tail index
syncwarp(); syncwarp();
if(lane_id == 0) { if (lane_id == 0)
st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail); st_relaxed_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
}
} }
// Retired // Retired
syncwarp(); syncwarp();
if(lane_id == 0) { if (lane_id == 0)
forward_channel_retired[dst_nvl_rank] = true; forward_channel_retired[dst_nvl_rank] = true;
} } else if (warp_role == WarpRole::kForwarderCoordinator) {
} else if(warp_role == WarpRole::kForwarderCoordinator) {
/*
这段代码的主要功能是在一个CUDA内核中协调转发器的逻辑。
它首先检查当前warp是否是额外的转发器协调warp,如果是,则直接退出。
然后,它清理共享内存,并初始化转发通道的头部和退役状态。
接着,它进入一个无限循环,在循环中,它找到最小的头部,如果所有的通道都已退役,则退出循环。
否则,它更新远程头部,并进行纳秒级睡眠,以让其他warp工作。
*/
// Extra warps for forwarder coordinator should exit directly // Extra warps for forwarder coordinator should exit directly
if (warp_id > NUM_MAX_NVL_PEERS) if (target_rank > 0)
return; return;
// 转发warp协调器 // Forward warp coordinator
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量"); EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
// 清理共享内存
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "无效的NVL对等体数量"); // Clean shared memory
#pragma unroll EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
for(int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += kWarpSize) for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += kWarpSize)
forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0; forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0;
if(lane_id < NUM_MAX_NVL_PEERS) if (lane_id < NUM_MAX_NVL_PEERS)
forward_channel_retired[lane_id] = false; forward_channel_retired[lane_id] = false;
// sync_forwarder_smem(); sync_forwarder_smem();
__syncthreads();
int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0; int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0;
while (true) {
while(true) { // Find minimum head
// 找到最小的头部
int min_head = std::numeric_limits<int>::max(); int min_head = std::numeric_limits<int>::max();
#pragma unroll #pragma unroll
for(int i = 0; i < NUM_MAX_NVL_PEERS; ++i) for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
if(!forward_channel_retired[i]) if (not forward_channel_retired[i])
min_head = min(min_head, forward_channel_head[i][target_rdma]); min_head = min(min_head, forward_channel_head[i][target_rdma]);
if (__all_sync(kFullWarpMask, min_head == std::numeric_limits<int>::max()))
if(__all_sync(kFullWarpMask, min_head == std::numeric_limits<int>::max())) {
break; break;
}
// 更新远程头部 // Update remote head
if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){ if (min_head != std::numeric_limits<int>::max() and
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) min_head >= last_head + num_max_rdma_chunked_send_tokens and
shmem_ctx_ulong_atomic_add(ctx, lane_id < kNumRDMARanks) {
#else rocshmem::rocshmem_ctx_ulong_atomic_add(
shmem_signal_op_add( ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_head,
#endif
rdma_channel_head.buffer(rdma_rank), min_head - last_head,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
last_head = min_head; last_head = min_head;
} }
// 纳秒级睡眠并让其他warp工作 // Nanosleep and let other warps work // Nanosleep and let other warps work
__builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64); __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
} }
} else if(warp_role == WarpRole::kNVLReceivers) { } else {
if(warp_id >= NUM_MAX_NVL_PEERS) { // NVL consumers
return; // Retrieve rank offset from barrier results (each lane's register stores an RDMA rank)
}
// Place the main logic of your kernel here, using the parameters above.
// NVL消费者
// 从屏障结果中检索秩偏移(每个通道的寄存器存储一个RDMA秩)
int src_nvl_rank = target_rank, total_offset = 0; int src_nvl_rank = target_rank, total_offset = 0;
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量"); EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
if(lane_id < kNumRDMARanks && lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0) if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1]; total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];
// 接收通道偏移 // Receive channel offsets
int start_offset = 0, end_offset = 0, num_tokens_to_recv; int start_offset = 0, end_offset = 0, num_tokens_to_recv;
auto start_time = wall_clock64(); auto start_time = wall_clock64();
while (lane_id < kNumRDMARanks) {
while(lane_id < kNumRDMARanks) {
start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id); start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id); end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
if(start_offset < 0 && end_offset < 0) { if (start_offset < 0 and end_offset < 0) {
start_offset = -start_offset - 1, end_offset = -end_offset - 1; start_offset = -start_offset - 1, end_offset = -end_offset - 1;
total_offset += start_offset; total_offset += start_offset;
break; break;
} }
// 超时检查
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { // Timeout check
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n", long long int elapsed_time =
channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset); wall_clock64() > start_time ? wall_clock64() - start_time : 0;
if (elapsed_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src "
"RDMA: %d, src nvl: %d, start: %d, end: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset,
end_offset);
trap(); trap();
} }
} }
num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset); num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);
// 保存以供合并使用 // Save for combine usage
if(lane_id < kNumRDMARanks && !kCachedMode) if (lane_id < kNumRDMARanks and not kCachedMode)
recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset; recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) *
num_channels +
channel_id] = total_offset;
syncwarp(); syncwarp();
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
while(num_tokens_to_recv > 0) { while (num_tokens_to_recv > 0) {
// 通过通道0检查通道状态 // Check channel status by lane 0
start_time = wall_clock64(); start_time = wall_clock64();
while(lane_id == 0) { while (lane_id == 0) {
// 准备复制 // Ready to copy
if(cached_channel_head_idx != cached_channel_tail_idx) if (cached_channel_head_idx != cached_channel_tail_idx)
break; break;
cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer()); cached_channel_tail_idx = ld_relaxed_sys_global(nvl_channel_tail.buffer());
// 超时检查
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { // Timeout check
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n", long long int elapsed_time =
channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx); wall_clock64() > start_time ? wall_clock64() - start_time : 0;
if (elapsed_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, "
"src NVL: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx,
cached_channel_tail_idx);
trap(); trap();
} }
} }
// 同步队列尾部 // Sync queue tail
cached_channel_tail_idx = shfl_sync(cached_channel_tail_idx, 0); cached_channel_tail_idx = shfl_sync(cached_channel_tail_idx, 0);
// 复制数据 // Copy data
int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx; int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
for(int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) { for (int chunk_idx = 0; chunk_idx < num_recv_tokens;
int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens; ++chunk_idx, --num_tokens_to_recv) {
auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer); int token_idx_in_buffer =
int64_t recv_token_idx = shfl_sync(total_offset, meta.src_rdma_rank); (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens;
auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
int64_t recv_token_idx = shfl_sync(total_offset, meta.src_rdma_rank);
(lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0; (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;
// 复制数据 // Copy data
UNROLLED_WARP_COPY(5, UNROLLED_WARP_COPY(5, lane_id, hidden_int4, recv_x + recv_token_idx * hidden_int4,
lane_id, nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
hidden_int4, ld_nc_global, st_na_global);
recv_x + recv_token_idx * hidden_int4,
nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4, // Copy source meta
ld_nc_global, if (lane_id == 0 and not kCachedMode)
st_na_global);
// 复制源元数据
if(lane_id == 0 && !kCachedMode)
st_na_global(recv_src_meta + recv_token_idx, meta); st_na_global(recv_src_meta + recv_token_idx, meta);
// 复制比例 // Copy scales
UNROLLED_WARP_COPY(1, UNROLLED_WARP_COPY(1, lane_id, num_scales,
lane_id, recv_x_scales + recv_token_idx * num_scales,
num_scales, nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
recv_x_scales + recv_token_idx * num_scales, ld_nc_global, st_na_global);
nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
ld_nc_global, // Copy `topk_idx` and `topk_weights`
st_na_global); if (lane_id < num_topk) {
// 复制 `topk_idx` 和 `topk_weights`
if(lane_id < num_topk) {
auto recv_idx = recv_token_idx * num_topk + lane_id; auto recv_idx = recv_token_idx * num_topk + lane_id;
auto buffer_idx = token_idx_in_buffer * num_topk + lane_id; auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx))); st_na_global(recv_topk_idx + recv_idx,
st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx)); static_cast<int64_t>(
ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
st_na_global(recv_topk_weights + recv_idx,
ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
} }
} }
// 移动队列 // Move queue
syncwarp(); syncwarp();
if(lane_id == 0) { if (lane_id == 0)
st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx); st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
} }
} // while(num_tokens_to_recv > 0)
} }
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) rocshmem::rocshmem_wg_ctx_destroy(&ctx);
shmem_wg_ctx_destroy(&ctx);
#endif
} }
void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights, void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
...@@ -1152,8 +1135,6 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float ...@@ -1152,8 +1135,6 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float
int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels, int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels,
bool low_latency_mode) { bool low_latency_mode) {
constexpr int kNumDispatchRDMASenderWarps = 7; constexpr int kNumDispatchRDMASenderWarps = 7;
// Make sure never OOB
EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ #define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \
{ \ { \
...@@ -1181,8 +1162,8 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float ...@@ -1181,8 +1162,8 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float
EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr)); EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr)); EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));
SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, SETUP_LAUNCH_CONFIG(num_channels * 2,
(1 + NUM_MAX_NVL_PEERS) * kWarpSize, stream); (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * kWarpSize, stream);
SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE); SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE #undef DISPATCH_LAUNCH_CASE
} }
...@@ -1209,7 +1190,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i ...@@ -1209,7 +1190,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if (sm_id == 0) { if (sm_id == 0) {
// Barrier for RDMA // Barrier for RDMA
if (thread_id == kWarpSize) if (thread_id == kWarpSize)
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team); shmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
// Barrier for NVL // Barrier for NVL
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank); barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
...@@ -1228,7 +1209,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i ...@@ -1228,7 +1209,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
// Barrier again // Barrier again
if (thread_id == kWarpSize) if (thread_id == kWarpSize)
dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team); shmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
// Barrier again // Barrier again
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank); barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
...@@ -1236,25 +1217,24 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i ...@@ -1236,25 +1217,24 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if (is_cached_dispatch) if (is_cached_dispatch)
return; return;
EP_DEVICE_ASSERT(num_warps >= num_channels);
EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize); EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize);
// Iterate in reverse order // Iterate in reverse order
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { if (lane_id < num_rdma_ranks and warp_id < num_channels) {
if (lane_id < num_rdma_ranks) { int token_start_idx, token_end_idx;
int token_start_idx, token_end_idx; get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx,
get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
token_end_idx);
// NOTES: `1 << 25` is a heuristic large number // NOTES: `1 << 25` is a heuristic large number
int last_head = 1 << 25; int last_head = 1 << 25;
for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) { for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
auto current_head = auto current_head =
__ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id); __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
if (current_head < 0) { if (current_head < 0) {
combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1; combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
} else { } else {
last_head = current_head; last_head = current_head;
}
} }
} }
} }
...@@ -1262,34 +1242,34 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i ...@@ -1262,34 +1242,34 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if (is_cached_dispatch) if (is_cached_dispatch)
return; return;
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr); EP_DEVICE_ASSERT(num_warps >= num_channels);
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and
rdma_rank_prefix_sum != nullptr);
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers"); EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers");
constexpr int num_clean_sms = 2; constexpr int num_clean_sms = 2;
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) {
if (lane_id < NUM_MAX_NVL_PEERS ) { for (int dst_rdma_rank = sm_id - num_clean_sms; dst_rdma_rank < num_rdma_ranks;
for (int dst_rdma_rank = sm_id - num_clean_sms; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - num_clean_sms) {
dst_rdma_rank += num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL - num_clean_sms) { // Iterate in reverse order
// Iterate in reverse order int token_start_idx =
int token_start_idx = warp_id == 0
channel_id == 0 ? 0
? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
: rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]; int token_end_idx =
int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id]; int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; token_start_idx += shift, token_end_idx += shift;
token_start_idx += shift, token_end_idx += shift;
// NOTES: `1 << 25` is a heuristic large number
// NOTES: `1 << 25` is a heuristic large number int last_head = 1 << 25;
int last_head = 1 << 25; for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) { auto current_head =
auto current_head = __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
__ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); if (current_head < 0) {
if (current_head < 0) { combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1; } else {
} else { last_head = current_head;
last_head = current_head;
}
} }
} }
} }
...@@ -1305,7 +1285,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -1305,7 +1285,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank, int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode) { bool is_cached_dispatch, bool low_latency_mode) {
const int num_threads = ::min(1024, ::max(128, kWarpSize * num_channels)); const int num_threads = ::max(128, kWarpSize * num_channels);
const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
// Get clean meta // Get clean meta
...@@ -1321,11 +1301,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -1321,11 +1301,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
num_nvl_bytes); num_nvl_bytes);
EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max()); EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max()); EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL > 2); EP_HOST_ASSERT(num_channels * 2 > 2);
// Launch kernel // Launch kernel
auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>; auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, num_threads, stream); SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
LAUNCH_KERNEL_NON_COOPERATIVE( LAUNCH_KERNEL_NON_COOPERATIVE(
&cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second, &cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second,
nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens, nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens,
...@@ -1334,45 +1314,49 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -1334,45 +1314,49 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
cpu_rdma_team); cpu_rdma_team);
} }
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, bool kUseMLS, typename GetAddrFn, typename ReceiveTWFn> template <int kNumRanks, typename dtype_t, int kMaxNumRanks, int kWidth, typename ReceiveFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx, __device__ int combine_token(bool is_token_in_rank, int head_idx,
int lane_id, int hidden_int4, int num_topk, int lane_id, int hidden_int4, int num_topk,
int4* combined_row, float* combined_topk_weights, int4* combined_row, float* combined_topk_weights,
int num_max_recv_tokens, int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
const GetAddrFn& get_addr_fn, const ReceiveTWFn& recv_tw_fn) {
constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
// Broadcast current heads // Broadcast current heads
// Lane `i` holds the head of rank `i` and `is_token_in_rank` // Lane `i` holds the head of rank `i` and `is_token_in_rank`
EP_STATIC_ASSERT(kMaxNumRanks <= kWarpSize, "Too many ranks"); EP_STATIC_ASSERT(kMaxNumRanks <= kWidth, "Too many ranks");
int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks]; int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];
#pragma unroll #pragma unroll
for (int i = 0; i < kNumRanks; ++ i) if (shfl_sync(is_token_in_rank, i)) { for (int i = 0; i < kNumRanks; ++ i) if (shfl_sync(is_token_in_rank, i, kWidth)) {
slot_indices[num_topk_ranks] = shfl_sync(head_idx, i) % num_max_recv_tokens; slot_indices[num_topk_ranks] = shfl_sync(head_idx, i, kWidth) % num_max_recv_tokens;
topk_ranks[num_topk_ranks ++] = i; topk_ranks[num_topk_ranks ++] = i;
} }
EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks); EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks);
// Reduce data // Reduce data
#pragma unroll #pragma unroll
for (int i = lane_id; i < hidden_int4; i += kWarpSize) { for (int i = lane_id; i < hidden_int4; i += kWidth) {
// Read buffers float values[kDtypePerInt4] = {0};
float values[kDtypePerInt4] = {0}; // 8 × 4B = 32B
// Temporary buffer
int4 temp;
#pragma unroll #pragma unroll
for (int j = 0; j < num_topk_ranks; ++j) { for (int j = 0; j < num_topk_ranks; ++j) {
int4 recv_value = ld_nc_global(get_addr_fn(topk_ranks[j], slot_indices[j], i)); temp = recv_fn(topk_ranks[j], slot_indices[j], i);
auto recv_dtypes = reinterpret_cast<const dtype_t*>(&recv_value); const dtype_t* d = reinterpret_cast<const dtype_t*>(&temp);
#pragma unroll #pragma unroll
for (int k = 0; k < kDtypePerInt4; ++k) for (int k = 0; k < kDtypePerInt4; ++k)
values[k] += static_cast<float>(recv_dtypes[k]); values[k] += static_cast<float>(d[k]);
} }
// Cast back to `dtype_t` and write
int4 out_int4; int4 out_int4;
auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4); dtype_t* out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
#pragma unroll #pragma unroll
for (int j = 0; j < kDtypePerInt4; ++ j) for (int j = 0; j < kDtypePerInt4; ++j)
out_dtypes[j] = static_cast<dtype_t>(values[j]); out_dtypes[j] = static_cast<dtype_t>(values[j]);
st_na_global(combined_row + i, out_int4); st_na_global(combined_row + i, out_int4);
} }
...@@ -1389,87 +1373,98 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx, ...@@ -1389,87 +1373,98 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
return topk_ranks[0]; return topk_ranks[0];
} }
template <bool kLowLatencyMode, template<bool kLowLatencyMode,
int kNumRDMARanks, int kNumRDMARanks, typename dtype_t,
typename dtype_t, int kNumCombineForwarderWarps,
int kNumCombineForwarderWarps, int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks),
int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks), int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1, int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder,
int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder, int kNumRDMAReceivers = kNumRDMARanks <=8 ? kNumForwarders + NUM_MAX_NVL_PEERS / 2: kNumForwarders + NUM_MAX_NVL_PEERS,
int kNumRDMAReceivers = kNumForwarders> int kBlockThreads = (kNumRDMARanks > 8) ? ((NUM_MAX_NVL_PEERS + kNumForwarders) * kEmulatedWarpSize + kWarpSize) : ((NUM_MAX_NVL_PEERS/2 + 1 + kNumForwarders) * kWarpSize) >
__global__ void __launch_bounds__((1 + NUM_MAX_NVL_PEERS) * kWarpSize, 1) __global__ void __launch_bounds__(kBlockThreads, 1)
combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_token_in_rank, combine(int4* combined_x, float* combined_topk_weights,
const int4 *x, const float *topk_weights, const int4 *bias_0, const int4 *bias_1, const bool* is_combined_token_in_rank,
const int *combined_rdma_head, const int *combined_nvl_head, const SourceMeta *src_meta, const int4* x, const float* topk_weights, const int4 *bias_0, const int4 *bias_1,
const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum, const int* combined_rdma_head, const int* combined_nvl_head,
const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens, const SourceMeta* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_tokens, int num_combined_tokens, int hidden, int num_topk,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int num_ranks) { int rank, int num_ranks) {
enum class WarpRole { enum class WarpRole {
kNVLSender, kNVLSender,
kNVLAndRDMAForwarder, kNVLAndRDMAForwarder,
kRDMAReceiver, kRDMAReceiver,
kRDMACoordinator, kCoordinator
kNVLCoordinator
}; };
constexpr auto kNVLPeersHyb = (kNumRDMARanks > 8) ? NUM_MAX_NVL_PEERS : NUM_MAX_NVL_PEERS / 2;
constexpr auto kWarpHyb = kNumRDMARanks > 8 ? kEmulatedWarpSize : kWarpSize;
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x);
const int num_warps = kNumRDMARanks > 8 ? (num_threads / kEmulatedWarpSize - 1) : (num_threads / kWarpSize);
auto thread_id = static_cast<int>(threadIdx.x);
const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2;
const bool is_rdma_receiver_sm = sm_id % 2 == 1;
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
__shared__ shmem_ctx_t ctx; __shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx); shmem_wg_ctx_create(&ctx);
#endif
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
const auto sm_id = static_cast<int>(blockIdx.x); #endif
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
const auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
channel_id = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
EP_DEVICE_ASSERT(num_topk <= kEmulatedWarpSize);
EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0);
const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t)); const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));
// NOTES: we decouple a channel into 2 SMs // NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
auto role_meta = [=]() -> std::pair<WarpRole, int> {
const auto role_meta = [=]() -> std::pair<WarpRole, int> { auto warp_id = thread_id / kWarpHyb;
if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) { if (not is_rdma_receiver_sm) {
return {WarpRole::kNVLSender, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; if (warp_id < kNVLPeersHyb) {
} else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) { auto shuffled_warp_id = warp_id;
if(warp_id < kNumForwarders) { shuffled_warp_id = (shuffled_warp_id + channel_id) % kNVLPeersHyb;
return {WarpRole::kNVLAndRDMAForwarder, (warp_id + channel_id) % kNumForwarders}; return {WarpRole::kNVLSender, shuffled_warp_id};
} else if (warp_id < kNVLPeersHyb + kNumForwarders) {
auto shuffled_warp_id = warp_id - kNVLPeersHyb;
shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders;
return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id};
} else { } else {
return {WarpRole::kRDMACoordinator, 0}; return {WarpRole::kCoordinator, 0};
} }
} else { } else {
if(warp_id < kNumForwarders) { if (warp_id < kNVLPeersHyb + kNumForwarders) {
return {WarpRole::kRDMAReceiver, warp_id}; return {WarpRole::kRDMAReceiver, warp_id};
} else { } else {
return {WarpRole::kNVLCoordinator, 0}; return {WarpRole::kCoordinator, 0};
} }
} }
}(); }();
auto warp_role = role_meta.first; auto warp_role = role_meta.first;
auto target_rank = role_meta.second; // Not applicable for RDMA senders auto warp_id = role_meta.second;
EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + 1);
auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks; auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;
// This approach is designed to sync multiple warps in a loop // This approach is designed to sync multiple warps in a loop
constexpr int num_sync_large_iteration = 64; constexpr int num_sync_large_iteration = 64;
constexpr int rdma_warp_counters = kNumRDMARanks * num_sync_large_iteration; __shared__ volatile int rdma_receiver_counter[1];
__shared__ volatile int sync_large_warp_counters[2 * rdma_warp_counters]; __shared__ volatile int rdma_forwarder_counter[1];
for (int i = thread_id; i < 2 * rdma_warp_counters; i += num_threads) { __shared__ volatile uint8_t sync_large_warp_counters[2 * kNumRDMARanks * num_sync_large_iteration ];
if (threadIdx.x==0){
rdma_receiver_counter[0] = 0;
rdma_forwarder_counter[0] = 0;
}
for (int i = thread_id; i < 2 * kNumRDMARanks * num_sync_large_iteration; i += num_threads) {
sync_large_warp_counters[i] = 0; sync_large_warp_counters[i] = 0;
} }
__syncthreads(); __syncthreads();
if (warp_role == WarpRole::kNVLSender) { if (warp_role == WarpRole::kNVLSender) {
if(warp_id >= NUM_MAX_NVL_PEERS) { // NVL producers
return; const int dst_nvl_rank = kNumRDMARanks <= 8 ? (warp_id * 2 + (thread_id % kWarpSize) / kEmulatedWarpSize) : warp_id;
} auto lane_id = get_lane_id() % kEmulatedWarpSize;
const auto dst_nvl_rank = target_rank;
// NVL layouts // NVL layouts
// NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank]; auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
...@@ -1481,103 +1476,90 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1481,103 +1476,90 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Get tasks for each RDMA lane // Get tasks for each RDMA lane
int token_start_idx = 0, token_end_idx = 0; int token_start_idx = 0, token_end_idx = 0;
if(lane_id < kNumRDMARanks) { if (lane_id < kNumRDMARanks) {
int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id; int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
token_start_idx = gbl_channel_prefix_matrix[prefix_idx]; token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1]; token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
} }
syncwarp(); syncwarp();
// NOTES: here the cached value of each lane is only responsible for a single RDMA buffer // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers"); EP_STATIC_ASSERT(kNumRDMARanks <= kEmulatedWarpSize, "Invalid number of RDMA peers");
// Iterate over all tokens and send by chunks // Iterate over all tokens and send by chunks
while(true) { while (true) {
// Exit if possible // Exit if possible
if(__all_sync(kFullWarpMask, token_start_idx >= token_end_idx)) if (__all_sync(kFullWarpMask, token_start_idx >= token_end_idx))
break; break;
// Decide next RDMA buffer to send // Decide next RDMA buffer to send
bool is_lane_ready = false; bool is_lane_ready = false;
auto start_time = wall_clock64(); auto start_time = wall_clock64();
while (true) {
while(true) {
int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx; int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx;
is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens;
num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens; if(__any_sync(kFirstHalfMask, is_lane_ready))
break;
if(__any_sync(kFullWarpMask, is_lane_ready)) if(__any_sync(kSecondHalfMask, is_lane_ready))
break; break;
// Retry // Retry
if(lane_id < kNumRDMARanks and token_start_idx < token_end_idx) if (lane_id < kNumRDMARanks and token_start_idx < token_end_idx)
cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id); cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id);
// Timeout check // Timeout check
if(wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { long long int elapsed_time = wall_clock64() > start_time ? wall_clock64() - start_time : 0;
printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, " if (elapsed_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
"RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n", printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n",
channel_id, channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, ld_volatile_global(nvl_channel_head.buffer() + lane_id), cached_channel_tail_idx,
rdma_rank, token_start_idx, token_end_idx);
nvl_rank,
dst_nvl_rank,
lane_id,
ld_volatile_global(nvl_channel_head.buffer() + lane_id),
cached_channel_tail_idx,
token_start_idx,
token_end_idx);
trap(); trap();
} }
__builtin_amdgcn_s_sleep(1);
} }
// Sync token start index and count // Sync token start index and count
for(int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++current_rdma_idx) { for (int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++ current_rdma_idx) {
if(shfl_sync((token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx)) if (shfl_sync((token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx, kEmulatedWarpSize))
continue; continue;
// Sync token start index // Sync token start index
auto token_idx = static_cast<int64_t>(shfl_sync(token_start_idx, current_rdma_idx)); auto token_idx = static_cast<int64_t>(shfl_sync(token_start_idx, current_rdma_idx, kEmulatedWarpSize));
int num_tokens_in_chunk = shfl_sync(min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx); int num_tokens_in_chunk = shfl_sync(min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx, kEmulatedWarpSize);
// Send by chunk // Send by chunk
for(int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) { for (int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++ chunk_idx, ++ token_idx) {
// Get an empty slot // Get an empty slot
int dst_slot_idx = 0; int dst_slot_idx = 0;
if(lane_id == current_rdma_idx) { if (lane_id == current_rdma_idx) {
dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma; dst_slot_idx = (cached_channel_tail_idx ++) % num_max_nvl_chunked_recv_tokens_per_rdma;
dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx; dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx;
} }
dst_slot_idx = shfl_sync(dst_slot_idx, current_rdma_idx); dst_slot_idx = shfl_sync(dst_slot_idx, current_rdma_idx, kEmulatedWarpSize);
// Copy data // Copy data
auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4; auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
auto shifted_x = x + token_idx * hidden_int4; auto shifted_x = x + token_idx * hidden_int4;
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global); UNROLLED_WARP_COPY_EMULATED(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
// Copy source meta // Copy source meta
if(lane_id == 0) if (lane_id == num_topk)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx)); st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx));
// Copy `topk_weights` // Copy `topk_weights`
if(lane_id < num_topk) if (lane_id < num_topk)
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
} }
lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0; lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
} }
// Move queue tail // Move queue tail
syncwarp(); syncwarp();
if(lane_id < kNumRDMARanks and is_lane_ready) { if (lane_id < kNumRDMARanks and is_lane_ready)
st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx); st_relaxed_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
}
} }
} else { } else {
if(warp_id > kNumForwarders) { auto lane_id = get_lane_id() % kWarpHyb;
return;
}
// Combiners and coordinators // Combiners and coordinators
// RDMA symmetric layout // RDMA symmetric layout
auto hidden_bytes = hidden_int4 * sizeof(int4); auto hidden_bytes = hidden_int4 * sizeof(int4);
...@@ -1604,53 +1586,65 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1604,53 +1586,65 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
__shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks]; __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];
__shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers]; __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];
// Arrival-counter rendezvous on `rdma_forwarder_counter[0]`: lane 0 of each calling
// warp adds 1 to the counter, then every lane of that warp spins until the counter
// has reached `kNumForwarders + 1`.
// NOTE(review): the `+ 1` presumably accounts for one coordinator warp joining the
// `kNumForwarders` forwarder warps — confirm against the call sites.
// NOTE(review): the counter is not reset inside this lambda, so a second rendezvous
// on the same counter would fall straight through — confirm it is used once per launch.
auto sync_forwarder_smem = [&]() {
if (lane_id==0) {
// Relaxed, workgroup-scope increment. The fetched value `ret` is not used afterwards.
// NOTE(review): relaxed ordering alone does not order the caller's earlier writes
// before the increment — confirm visibility is guaranteed by other means.
volatile int ret = __hip_atomic_fetch_add(&rdma_forwarder_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
}
syncwarp();
// Busy-wait with an empty body; there is no timeout or sleep in this loop.
// NOTE(review): this relies on the counter being re-read on every iteration
// (e.g. a `volatile` declaration, which is not visible here) — confirm.
while(rdma_forwarder_counter[0]<(kNumForwarders + 1)){}
};
// Arrival-counter rendezvous on `rdma_receiver_counter[0]`: lane 0 of each calling
// warp adds 1 to the counter, then every lane of that warp spins until the counter
// has reached `kNumRDMAReceivers + 1`.
// NOTE(review): the `+ 1` presumably accounts for one coordinator warp joining the
// `kNumRDMAReceivers` receiver warps — confirm against the call sites.
// NOTE(review): the counter is not reset inside this lambda, so a second rendezvous
// on the same counter would fall straight through — confirm it is used once per launch.
auto sync_rdma_receiver_smem = [&]() {
if (lane_id==0) {
// Relaxed, workgroup-scope increment. The fetched value `ret` is not used afterwards.
// NOTE(review): relaxed ordering alone does not order the caller's earlier writes
// before the increment — confirm visibility is guaranteed by other means.
volatile int ret = __hip_atomic_fetch_add(&rdma_receiver_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
}
syncwarp();
// Busy-wait with an empty body; there is no timeout or sleep in this loop.
// NOTE(review): this relies on the counter being re-read on every iteration
// (e.g. a `volatile` declaration, which is not visible here) — confirm.
while(rdma_receiver_counter[0]<(kNumRDMAReceivers+1)){}
};
if (warp_role == WarpRole::kNVLAndRDMAForwarder) { if (warp_role == WarpRole::kNVLAndRDMAForwarder) {
// Receive from NVL ranks and forward to RDMA ranks // Receive from NVL ranks and forward to RDMA ranks
// NOTES: this part is using "large warps" for each RDMA ranks // NOTES: this part is using "large warps" for each RDMA ranks
const auto dst_rdma_rank = target_rank / kNumWarpsPerForwarder; const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder;
const auto sub_warp_id = target_rank % kNumWarpsPerForwarder; const auto sub_warp_id = warp_id % kNumWarpsPerForwarder;
auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank); auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank);
// auto sync_large_warp = [=]() {
// if(kNumWarpsPerForwarder == 1) {
// syncwarp();
// } else {
// // asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * kWarpSize));
// // __syncthreads();
// syncwarp();
// }
// };
auto sync_large_warp = [=](const int iter, const int mode) { auto sync_large_warp = [=](const int iter, const int mode) {
if (kNumWarpsPerForwarder == 1) { if (kNumWarpsPerForwarder == 1) {
syncwarp(); syncwarp();
} else { } else {
// LDS index to store for sync // LDS index to store for sync
int lds_dst_rdma_rank = dst_rdma_rank + (iter % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters; int lds_dst_rdma_rank = dst_rdma_rank + (iter % num_sync_large_iteration) * kNumRDMARanks + mode * kNumRDMARanks * num_sync_large_iteration;
//reset index in the LDS to avoid race condition due to warp scheduling //reset index in the LDS to avoid race condition due to warp scheduling
int reset_idx = dst_rdma_rank + ((iter + num_sync_large_iteration/2) % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters; int reset_idx = dst_rdma_rank + ((iter + num_sync_large_iteration/2) % num_sync_large_iteration) * kNumRDMARanks + mode * kNumRDMARanks * num_sync_large_iteration;
auto start_time = wall_clock64(); // if (lane_id==0)
if (lane_id == 0){ // printf("rank %d dst_rdma_rank %d iter %d warp_id %d val %d\n", rank, dst_rdma_rank, iter, warp_id, sync_large_warp_counters[lds_dst_rdma_rank]);
volatile int ret = atomicAdd((int*)&sync_large_warp_counters[lds_dst_rdma_rank], 1); auto start_time = clock64();
} if (lane_id == 0){
syncwarp(); volatile int ret = __hip_atomic_fetch_add(
//The while(...) loop polls the counter until all warps have arrived &sync_large_warp_counters[lds_dst_rdma_rank], 1,
if (lane_id == 0){ __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
while (sync_large_warp_counters[lds_dst_rdma_rank] < (kNumWarpsPerForwarder)){ }
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { syncwarp();
printf("DeepEP combine sync timeout. current num_sync_large_iteration %d. double it.\n", num_sync_large_iteration ); //The while(...) loop polls the counter until all warps have arrived
trap(); if (lane_id == 0){
} while (sync_large_warp_counters[lds_dst_rdma_rank] < (kNumWarpsPerForwarder)){
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine sync timeout. current num_sync_large_iteration %d. double it.\n", num_sync_large_iteration );
trap();
} }
} }
syncwarp(); }
if (lane_id == 0 && sync_large_warp_counters[reset_idx] == kNumWarpsPerForwarder){ syncwarp();
sync_large_warp_counters[reset_idx] = 0; if (lane_id == 0 && sync_large_warp_counters[reset_idx] == kNumWarpsPerForwarder){
} sync_large_warp_counters[reset_idx] = 0;
syncwarp(); }
syncwarp();
} }
}; };
EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= kNumCombineForwarderWarps, "Barriers are not enough"); EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough");
// Advance to the corresponding NVL buffer, 基于原本指针进行的地址偏移 // In case of running less than 8 nodes
constexpr bool kUseWave = (kNumRDMARanks <= 8);
// Advance to the corresponding NVL buffer
nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4); nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4);
nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma); nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk); nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
...@@ -1659,10 +1653,9 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1659,10 +1653,9 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Clean shared memory and sync // Clean shared memory and sync
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers"); EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[target_rank][lane_id] = 0) : 0; lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[warp_id][lane_id] = 0) : 0;
lane_id == 0 ? (forwarder_retired[target_rank] = false) : false; lane_id == 0 ? (forwarder_retired[warp_id] = false) : false;
// sync_forwarder_smem(); sync_forwarder_smem();
__syncthreads();
// Get count and cached head // Get count and cached head
int cached_nvl_channel_tail_idx = 0; int cached_nvl_channel_tail_idx = 0;
...@@ -1673,89 +1666,106 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1673,89 +1666,106 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS; combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS;
// Iterate over all tokens and combine by chunks // Iterate over all tokens and combine by chunks
for(int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) { for (int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) {
// Check destination queue emptiness, or wait a buffer to be released // Check destination queue emptiness, or wait a buffer to be released
auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine); auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);
auto num_chunked_tokens = token_end_idx - token_start_idx; auto num_chunked_tokens = token_end_idx - token_start_idx;
auto start_time = wall_clock64(); auto start_time = wall_clock64();
while(sub_warp_id == 0 and lane_id == 0) { while (sub_warp_id == 0 and lane_id == 0) {
// Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens` // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
// Here, `token_start_idx` is the actual tail // Here, `token_start_idx` is the actual tail
int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)); int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));
if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
if(num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
break; break;
// Timeout check // Timeout check
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { long long int elapsed_time = wall_clock64() > start_time ? wall_clock64() - start_time : 0;
if (elapsed_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n", printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n",
channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens); channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens);
trap(); trap();
} }
} }
// sync_large_warp();
sync_large_warp(token_start_idx, 0); sync_large_warp(token_start_idx, 0);
// Combine and write to the RDMA buffer // Combine and write to the RDMA buffer
for(int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) { for (int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) {
// Read expected head // Read expected head
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers"); EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
int expected_head = -1; int expected_head = -1;
if(lane_id < NUM_MAX_NVL_PEERS) if (lane_id < NUM_MAX_NVL_PEERS)
expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
// Wait lanes to be ready // Wait lanes to be ready
start_time = wall_clock64(); start_time = wall_clock64();
while(cached_nvl_channel_tail_idx <= expected_head) { while (cached_nvl_channel_tail_idx <= expected_head) {
cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id)); cached_nvl_channel_tail_idx = ld_relaxed_sys_global(nvl_channel_tail.buffer(lane_id));
// Timeout check // Timeout check
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) { long long int elapsed_time = wall_clock64() > start_time ? wall_clock64() - start_time : 0;
if (elapsed_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) {
printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n", printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head); channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head);
trap(); trap();
} }
__builtin_amdgcn_s_sleep(1);
} }
// Combine current token // Combine current token
auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token; void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
auto get_addr_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast<int4*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4) + hidden_int4_idx; }; auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); }; auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS, true>(expected_head >= 0, combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS, kWarpHyb>(expected_head >= 0,
expected_head, lane_id, expected_head, lane_id,
hidden_int4, num_topk, hidden_int4, num_topk,
reinterpret_cast<int4*>(shifted), reinterpret_cast<int4*>(shifted),
reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)), reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
num_max_nvl_chunked_recv_tokens_per_rdma, num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
get_addr_fn, recv_tw_fn);
// Update head // Update head
if(lane_id < NUM_MAX_NVL_PEERS) { if (lane_id < NUM_MAX_NVL_PEERS)
expected_head < 0 ? (forwarder_nvl_head[target_rank][lane_id] = -expected_head - 1) expected_head < 0 ? (forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1) : (forwarder_nvl_head[warp_id][lane_id] = expected_head + 1);
: (forwarder_nvl_head[target_rank][lane_id] = expected_head + 1);
}
} }
// sync_large_warp();
sync_large_warp(token_start_idx, 1); sync_large_warp(token_start_idx, 1);
// Issue RDMA send // Issue RDMA send
if(sub_warp_id == kNumWarpsPerForwarder - 1) { // TODO: Switch back to put_nbi_wave function
if(dst_rdma_rank != rdma_rank) {
if (sub_warp_id == kNumWarpsPerForwarder - 1 ) {
if (dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) #ifdef FORCE_DUSHMEM_API
shmem_ctx_schar_put_nbi_warp(ctx,
#else
shmemx_int8_put_nbi_warp( shmemx_int8_put_nbi_warp(
#endif
rdma_channel_data.recv_buffer(rdma_rank) + rdma_channel_data.recv_buffer(rdma_rank) +
rdma_slot_idx * num_bytes_per_rdma_token, rdma_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_channel_data.send_buffer(dst_rdma_rank) +
rdma_slot_idx * num_bytes_per_rdma_token, rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token, num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
#else
#if !defined(ROCM_DISABLE_CTX)
if constexpr (kUseWave){
shmem_ctx_schar_put_nbi_warp(ctx,
rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
#else
if constexpr (kUseWave){
shmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
} else {
if (lane_id == 0)
shmemx_int8_put_nbi(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
#endif
#endif
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet(ctx); shmem_ctx_quiet(ctx);
#else #else
shmem_fence(); shmem_fence();
#endif #endif
...@@ -1765,167 +1775,100 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1765,167 +1775,100 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Write new RDMA tail // Write new RDMA tail
syncwarp(); syncwarp();
if(lane_id == 0) { if (lane_id == 0)
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx, shmem_ctx_ulong_atomic_add(ctx,
#else #else
shmem_signal_op_add( shmem_signal_op_add(
#endif #endif
rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
} }
} }
// Retired // Retired
syncwarp(); syncwarp();
if(lane_id == 0) { if (lane_id == 0)
forwarder_retired[target_rank] = true; forwarder_retired[warp_id] = true;
} } else if (warp_role == WarpRole::kRDMAReceiver) {
} else if (warp_role == WarpRole::kRDMACoordinator) {
// Coordinator
// Sync shared memory status
// sync_forwarder_smem();
__syncthreads();
constexpr int num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;
int last_nvl_head[kNumRDMARanks] = {0};
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
while(true) {
// Retired
if(__all_sync(kFullWarpMask, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
break;
{
// Find minimum head for NVL ranks
#pragma unroll
for(int i = 0; i < kNumRDMARanks; ++i) {
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for(int j = 0; j < num_warps_per_rdma_rank; ++j)
if(not forwarder_retired[i * num_warps_per_rdma_rank + j])
min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);
if(min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) {
st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);
}
}
}
// Nanosleep and let other warps work
__builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
}
} else if(warp_role == WarpRole::kRDMAReceiver) {
// Receive from RDMA ranks and write to the output tensor // Receive from RDMA ranks and write to the output tensor
// Clean shared memory and sync // Clean shared memory and sync
EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize); EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[target_rank][lane_id] = 0) : 0; lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0;
lane_id == 0 ? (rdma_receiver_retired[target_rank] = false) : 0; lane_id == 0 ? (rdma_receiver_retired[warp_id] = false) : 0;
// sync_rdma_receiver_smem(); sync_rdma_receiver_smem();
__syncthreads();
// The same tokens as the dispatch process // The same tokens as the dispatch process
int token_start_idx, token_end_idx; int token_start_idx, token_end_idx;
get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx); get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// ==================== Token 级展开 x4 ==================== // Iterate over all tokens and combine
constexpr int kTokenUnroll = 4;
int cached_channel_tail_idx = 0; int cached_channel_tail_idx = 0;
for (int64_t token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {
for (int64_t base = token_start_idx + target_rank; // Read expected head
base < token_end_idx; EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
base += (int64_t)kNumRDMAReceivers * kTokenUnroll) { int expected_head = -1;
if (lane_id < kNumRDMARanks) {
// ---- Phase 1: 批量预取所有 token 的 expected_head ---- expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);
int cached_expected_head[kTokenUnroll]; (expected_head < 0) ? (rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1) : (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head);
int max_expected_head = -1;
#pragma unroll
for (int u = 0; u < kTokenUnroll; ++u) {
int64_t tidx = base + (int64_t)u * kNumRDMAReceivers;
cached_expected_head[u] = -1;
if (tidx < token_end_idx && lane_id < kNumRDMARanks) {
int expected_head = ld_nc_global(combined_rdma_head + tidx * kNumRDMARanks + lane_id);
cached_expected_head[u] = expected_head;
if (expected_head > max_expected_head) max_expected_head = expected_head;
}
} }
// ---- Phase 2: 一次等待,覆盖所有 token ---- // Wait lanes to be ready
if (max_expected_head >= 0) { auto start_time = wall_clock64();
auto start_time = wall_clock64(); while (cached_channel_tail_idx <= expected_head) {
while (cached_channel_tail_idx <= max_expected_head) { cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
cached_channel_tail_idx = static_cast<int>(
ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id))); // Timeout check
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) { long long int elapsed_time = wall_clock64() > start_time ? wall_clock64() - start_time : 0;
printf("DeepEP combine RDMA receiver timeout (unroll x%d), " if (elapsed_time > NUM_TIMEOUT_CYCLES) {
"ch: %d, rdma: %d, nvl: %d, lane: %d, " printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
"tail: %d, wait: %d\n", channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head);
kTokenUnroll, channel_id, rdma_rank, nvl_rank, trap();
lane_id, cached_channel_tail_idx, max_expected_head);
trap();
}
} }
__builtin_amdgcn_s_sleep(1);
} }
syncwarp(); syncwarp();
// ---- Phase 3: 批量处理所有就绪 token ---- // Combine current token
#pragma unroll auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
for (int u = 0; u < kTokenUnroll; ++u) { auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
int64_t tidx = base + (int64_t)u * kNumRDMAReceivers; combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks, kWarpHyb>(expected_head >= 0,
if (tidx < token_end_idx) { expected_head, lane_id,
int expected_head = cached_expected_head[u]; hidden_int4, num_topk,
// Combine current token combined_x + token_idx * hidden_int4,
auto get_addr_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast<int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx; }; combined_topk_weights + token_idx * num_topk,
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);}; num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks, false>(
expected_head >= 0, expected_head, lane_id,
hidden_int4, num_topk,
combined_x + tidx * hidden_int4,
combined_topk_weights + tidx * num_topk,
num_max_rdma_chunked_recv_tokens,
get_addr_fn, recv_tw_fn);
if (lane_id < kNumRDMARanks) {
rdma_receiver_rdma_head[target_rank][lane_id] =
expected_head < 0 ? -expected_head - 1 : expected_head;
}
}
}
} }
// Retired // Retired
syncwarp(); syncwarp();
if (lane_id == 0) { if (lane_id == 0)
rdma_receiver_retired[target_rank] = true; rdma_receiver_retired[warp_id] = true;
} } else {
} else if(warp_role == WarpRole::kNVLCoordinator) { auto lane_id = get_lane_id();
// Coordinator // Coordinator
// Sync shared memory status // Sync shared memory status
// sync_rdma_receiver_smem(); is_rdma_receiver_sm ? sync_rdma_receiver_smem() : sync_forwarder_smem();
__syncthreads();
const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks; const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;
int last_rdma_head = 0; int last_rdma_head = 0;
int last_nvl_head[kNumRDMARanks] = {0}; int last_nvl_head[kNumRDMARanks] = {0};
int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0; int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0;
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0; int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) { while (true) {
// Retired // Retired
if(__all_sync(kFullWarpMask, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id])) if (is_rdma_receiver_sm and __all_sync(kFullWarpMask, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
break;
if (not is_rdma_receiver_sm and __all_sync(kFullWarpMask, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
break; break;
// Find minimum head for RDMA ranks // Find minimum head for RDMA ranks
{ if (is_rdma_receiver_sm) {
int min_head = std::numeric_limits<int>::max(); int min_head = std::numeric_limits<int>::max();
#pragma unroll #pragma unroll
for(int i = 0; i < kNumRDMAReceivers; ++i) for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
if(not rdma_receiver_retired[i]) min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX) #if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add(ctx, shmem_ctx_ulong_atomic_add(ctx,
...@@ -1933,10 +1876,21 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1933,10 +1876,21 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
shmem_signal_op_add( shmem_signal_op_add(
#endif #endif
rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
last_rdma_head = min_head; last_rdma_head = min_head;
} }
} else {
// Find minimum head for NVL ranks
#pragma unroll
for (int i = 0; i < kNumRDMARanks; ++ i) {
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for (int j = 0; j < num_warps_per_rdma_rank; ++ j) if (not forwarder_retired[i * num_warps_per_rdma_rank + j])
min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS)
st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);
}
} }
// Nanosleep and let other warps work // Nanosleep and let other warps work
...@@ -1949,46 +1903,80 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1949,46 +1903,80 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
#endif #endif
} }
void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
const bool *is_combined_token_in_rank, const void *x, const float *topk_weights,
const void *bias_0, const void *bias_1, const int *combined_rdma_head,
const int *combined_nvl_head, const void *src_meta,
const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode) {
constexpr int kNumCombineForwarderWarps = 8;
#define COMBINE_LAUNCH_CASE(num_rdma_ranks) \ void combine(hipDataType type,
{ \ void* combined_x,
auto combine_func = \ float* combined_topk_weights,
low_latency_mode \ const bool* is_combined_token_in_rank,
? combine<true, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps> \ const void* x,
: combine<false, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>; \ const float* topk_weights,
LAUNCH_KERNEL_NON_COOPERATIVE( \ const void* bias_0,
&cfg, combine_func, reinterpret_cast<int4 *>(combined_x), combined_topk_weights, \ const void* bias_1,
is_combined_token_in_rank, reinterpret_cast<const int4 *>(x), topk_weights, \ const int* combined_rdma_head,
reinterpret_cast<const int4 *>(bias_0), reinterpret_cast<const int4 *>(bias_1), \ const int* combined_nvl_head,
combined_rdma_head, combined_nvl_head, reinterpret_cast<const SourceMeta *>(src_meta), \ const void* src_meta,
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \ const int* rdma_channel_prefix_matrix,
num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr, \ const int* rdma_rank_prefix_sum,
num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs, \ const int* gbl_channel_prefix_matrix,
num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks); \ int num_tokens,
} \ int num_combined_tokens,
break int hidden,
int num_topk,
int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; void* rdma_buffer_ptr,
auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1); int num_max_rdma_chunked_send_tokens,
int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder; int num_max_rdma_chunked_recv_tokens,
EP_HOST_ASSERT(num_forwarder_warps >= NUM_MAX_NVL_PEERS); void** buffer_ptrs,
EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0); int num_max_nvl_chunked_send_tokens,
int num_max_nvl_chunked_recv_tokens,
int rank,
int num_ranks,
hipStream_t stream,
int num_channels,
bool low_latency_mode)
{
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
EP_HOST_ASSERT(num_rdma_ranks > 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens)); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
EP_HOST_ASSERT(type == HIP_R_16BF); EP_HOST_ASSERT(type == HIP_R_16BF);
SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, (NUM_MAX_NVL_PEERS + 1) * kWarpSize, stream); // One case per compile-time NR specialization.
#define COMBINE_LAUNCH_CASE(NR) { \
/* Per-case compile-time constants */ \
constexpr int kNumCombineForwarderWarps = (NR < 9) ? 10 : 16; \
constexpr int kWarpsPerForwarder = (kNumCombineForwarderWarps/NR) > 0 \
? (kNumCombineForwarderWarps/NR) : 1; \
constexpr int kNumForwarders = NR * kWarpsPerForwarder; \
constexpr int kBlockThreads = (NR > 8) \
? ((NUM_MAX_NVL_PEERS + kNumForwarders) * kEmulatedWarpSize + kWarpSize) \
: ((NUM_MAX_NVL_PEERS/2 + 1 + kNumForwarders) * kWarpSize); \
\
SETUP_LAUNCH_CONFIG(num_channels * 2, kBlockThreads, stream); \
\
using scalar_t = hip_bfloat16; \
auto fn = low_latency_mode \
? combine<true, NR, scalar_t, kNumCombineForwarderWarps> \
: combine<false, NR, scalar_t, kNumCombineForwarderWarps>; \
\
/* Launch (backend-specific) */ \
\
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, fn, \
reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
reinterpret_cast<const int4*>(x), topk_weights, \
reinterpret_cast<const int4 *>(bias_0), reinterpret_cast<const int4 *>(bias_1), \
combined_rdma_head, combined_nvl_head, \
reinterpret_cast<const SourceMeta*>(src_meta), rdma_channel_prefix_matrix, \
rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
num_tokens, num_combined_tokens, hidden, num_topk, \
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, \
num_max_rdma_chunked_recv_tokens, \
buffer_ptrs, num_max_nvl_chunked_send_tokens, \
num_max_nvl_chunked_recv_tokens, \
rank, num_ranks); \
} break
// Dispatch on the runtime num_rdma_ranks, but each case is compile-time specialized.
SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE); SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE #undef COMBINE_LAUNCH_CASE
} }
...@@ -1997,8 +1985,4 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights, ...@@ -1997,8 +1985,4 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
} // namespace deep_ep } // namespace deep_ep
// #ifdef __clang__
// #pragma clang diagnostic pop
// #endif // __clang__
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment