Unverified commit 06f417dc, authored by Shangyan Zhou, committed by GitHub
Browse files

Use TMA to optimize internode combine. (#287)



* Let forwarders use a dedicated SM

* Shuffle rdma idx

* Sender uses TMA.

* Adjust the tuning chunk size.

* Modify NVL chunk layout.

* Update some combine config.

* Small lint

* Minor fix

* Overlap TMA store

---------
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
parent 1cf85fb2
...@@ -39,7 +39,7 @@ int get_source_meta_bytes() { ...@@ -39,7 +39,7 @@ int get_source_meta_bytes() {
} }
// Per-token buffer footprint in bytes, returned as `int`.
// The size is the sum of: `hidden_int4` `int4` values, one `SourceMeta`,
// `num_scales` floats, `num_topk_idx` ints and `num_topk_weights` floats;
// that sum is passed through `align(..., sizeof(int4))` before the cast.
// NOTE(review): `align` is defined outside this view; presumably it rounds up to a
// multiple of `sizeof(int4)` — confirm against its definition.
// NOTE: in this side-by-side diff rendering each line shows the pre-rename definition
// (`get_num_bytes_per_rdma_token`) followed by the post-rename one (`get_num_bytes_per_token`).
__host__ __device__ __forceinline__ __host__ __device__ __forceinline__
int get_num_bytes_per_rdma_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) { int get_num_bytes_per_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) {
return static_cast<int>(align(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4))); return static_cast<int>(align(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4)));
} }
...@@ -49,7 +49,7 @@ std::pair<int, int> get_rdma_clean_meta(int hidden_int4, int num_scales, int num ...@@ -49,7 +49,7 @@ std::pair<int, int> get_rdma_clean_meta(int hidden_int4, int num_scales, int num
int num_channels) { int num_channels) {
// Return `int32_t` offset and count to clean // Return `int32_t` offset and count to clean
return { return {
(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_channels) / sizeof(int), (get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_channels) / sizeof(int),
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels
}; };
} }
...@@ -59,13 +59,10 @@ std::pair<int, int> get_nvl_clean_meta(int hidden_int4, int num_scales, int num_ ...@@ -59,13 +59,10 @@ std::pair<int, int> get_nvl_clean_meta(int hidden_int4, int num_scales, int num_
int num_topk_weights, int num_rdma_ranks, int num_nvl_ranks, int num_topk_weights, int num_rdma_ranks, int num_nvl_ranks,
int num_nvl_recv_buffer_tokens, int num_channels, bool is_dispatch) { int num_nvl_recv_buffer_tokens, int num_channels, bool is_dispatch) {
// Return `int32_t` offset and count to clean // Return `int32_t` offset and count to clean
// TODO: remove `is_dispatch` after finishing combine refactor
EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`"); EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");
const int num_bytes_per_token = is_dispatch ?
get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) :
(hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float) + sizeof(SourceMeta));
return { return {
(num_nvl_recv_buffer_tokens * num_bytes_per_token * num_nvl_ranks * num_channels) / sizeof(int), (num_nvl_recv_buffer_tokens * get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_nvl_ranks * num_channels) / sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels, num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels,
}; };
} }
...@@ -401,8 +398,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -401,8 +398,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
auto hidden_bytes = hidden_int4 * sizeof(int4); auto hidden_bytes = hidden_int4 * sizeof(int4);
auto scale_bytes = num_scales * sizeof(float); auto scale_bytes = num_scales * sizeof(float);
// TODO: rename `num_bytes_per_rdma_token` after combine refactor // TODO: rename `num_bytes_per_rdma_token` after combine refactor
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk); auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, num_scales, num_topk, num_topk);
auto rdma_channel_data = SymBuffer<uint8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_data = SymBuffer<uint8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
...@@ -417,7 +414,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -417,7 +414,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank; rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
// Allocate buffers // Allocate buffers
auto nvl_channel_x = AsymBuffer<uint8_t>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_bytes_per_rdma_token, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_x = AsymBuffer<uint8_t>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_bytes_per_token, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr); auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
...@@ -440,7 +437,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -440,7 +437,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
mbarrier_init(tma_mbarrier, 1); mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared(); fence_view_async_shared();
fence_barrier_init(); fence_barrier_init();
EP_DEVICE_ASSERT(num_bytes_per_rdma_token + sizeof(uint64_t) <= kNumTMABytesPerWarp); EP_DEVICE_ASSERT(num_bytes_per_token + sizeof(uint64_t) <= kNumTMABytesPerWarp);
} }
__syncwarp(); __syncwarp();
...@@ -528,7 +525,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -528,7 +525,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64); auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);
if (lane_id == num_topk_ranks) if (lane_id == num_topk_ranks)
src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values); src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
dst_send_buffers[num_topk_ranks ++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token; dst_send_buffers[num_topk_ranks ++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_token;
} }
EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks); EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);
...@@ -656,9 +653,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -656,9 +653,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (dst_rdma_rank != rdma_rank) { if (dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens; auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue; const size_t num_bytes_per_msg = num_bytes_per_token * num_tokens_to_issue;
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_token);
nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg, nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
} else { } else {
...@@ -771,7 +768,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -771,7 +768,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Iterate over every token from the RDMA buffer // Iterate over every token from the RDMA buffer
for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) { for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) {
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens; auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
auto shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token; auto shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_token;
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(shifted + hidden_bytes + scale_bytes)); auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(shifted + hidden_bytes + scale_bytes));
lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0; lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;
bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank); bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
...@@ -786,17 +783,17 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -786,17 +783,17 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Get an empty slot // Get an empty slot
int dst_slot_idx = (cached_nvl_channel_tail ++) % num_max_nvl_chunked_recv_tokens; int dst_slot_idx = (cached_nvl_channel_tail ++) % num_max_nvl_chunked_recv_tokens;
auto dst_shifted = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_rdma_token; auto dst_shifted = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token;
// Copy data // Copy data
if (lane_id == 0) { if (lane_id == 0) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_rdma_token, false); tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_token, false);
mbarrier_arrive_and_expect_tx(tma_mbarrier, num_bytes_per_rdma_token); mbarrier_arrive_and_expect_tx(tma_mbarrier, num_bytes_per_token);
} }
__syncwarp(); __syncwarp();
mbarrier_wait(tma_mbarrier, tma_phase); mbarrier_wait(tma_mbarrier, tma_phase);
if (lane_id == 0) if (lane_id == 0)
tma_store_1d(tma_buffer, dst_shifted, num_bytes_per_rdma_token); tma_store_1d(tma_buffer, dst_shifted, num_bytes_per_token);
__syncwarp(); __syncwarp();
// In case of insufficient NVL buffers, early stopping // In case of insufficient NVL buffers, early stopping
...@@ -918,7 +915,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -918,7 +915,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx; int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx, -- num_tokens_to_recv) { for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx, -- num_tokens_to_recv) {
int token_idx_in_buffer = (cached_channel_head_idx ++) % num_max_nvl_chunked_recv_tokens; int token_idx_in_buffer = (cached_channel_head_idx ++) % num_max_nvl_chunked_recv_tokens;
auto shifted = nvl_channel_x.buffer() + token_idx_in_buffer * num_bytes_per_rdma_token; auto shifted = nvl_channel_x.buffer() + token_idx_in_buffer * num_bytes_per_token;
auto meta = ld_nc_global(reinterpret_cast<SourceMeta*>(shifted + hidden_bytes + scale_bytes)); auto meta = ld_nc_global(reinterpret_cast<SourceMeta*>(shifted + hidden_bytes + scale_bytes));
int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank); int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank);
(lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0; (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;
...@@ -1247,11 +1244,12 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx, ...@@ -1247,11 +1244,12 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
template<bool kLowLatencyMode, template<bool kLowLatencyMode,
int kNumRDMARanks, typename dtype_t, int kNumRDMARanks, typename dtype_t,
int kNumCombineForwarderWarps, int kNumCombineForwarderWarps,
int kNumTMABytesPerWarp,
int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks), int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks),
int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1, int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder, int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder,
int kNumRDMAReceivers = kNumForwarders + NUM_MAX_NVL_PEERS> int kNumRDMAReceivers = kNumForwarders - NUM_MAX_NVL_PEERS>
__global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1) __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1)
combine(int4* combined_x, float* combined_topk_weights, combine(int4* combined_x, float* combined_topk_weights,
const bool* is_combined_token_in_rank, const bool* is_combined_token_in_rank,
const int4* x, const float* topk_weights, const int4* x, const float* topk_weights,
...@@ -1273,31 +1271,32 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1273,31 +1271,32 @@ combine(int4* combined_x, float* combined_topk_weights,
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32; const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id(); const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2; const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2;
const bool is_rdma_receiver_sm = sm_id % 2 == 1; const bool is_forwarder_sm = sm_id % 2 == 1;
EP_DEVICE_ASSERT(num_topk <= 32); EP_DEVICE_ASSERT(num_topk <= 32);
EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0); EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0);
const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t)); const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));
const auto hidden_bytes = hidden_int4 * sizeof(int4);
const auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, 0, 0, num_topk);
// NOTES: we decouple a channel into 2 SMs // NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
auto role_meta = [=]() -> std::pair<WarpRole, int> { auto role_meta = [=]() -> std::pair<WarpRole, int> {
auto warp_id = thread_id / 32; auto warp_id = thread_id / 32;
if (not is_rdma_receiver_sm) { if (not is_forwarder_sm) {
if (warp_id < NUM_MAX_NVL_PEERS) { if (warp_id < NUM_MAX_NVL_PEERS) {
auto shuffled_warp_id = warp_id; auto shuffled_warp_id = warp_id;
shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS;
return {WarpRole::kNVLSender, shuffled_warp_id}; return {WarpRole::kNVLSender, shuffled_warp_id};
} else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { } else if (warp_id < kNumForwarders) {
auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS; return {WarpRole::kRDMAReceiver, warp_id - NUM_MAX_NVL_PEERS};
shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders;
return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id};
} else { } else {
return {WarpRole::kCoordinator, 0}; return {WarpRole::kCoordinator, 0};
} }
} else { } else {
if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { if (warp_id < kNumForwarders) {
return {WarpRole::kRDMAReceiver, warp_id}; auto shuffled_warp_id = (warp_id + channel_id) % kNumForwarders;
return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id};
} else { } else {
return {WarpRole::kCoordinator, 0}; return {WarpRole::kCoordinator, 0};
} }
...@@ -1306,7 +1305,7 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1306,7 +1305,7 @@ combine(int4* combined_x, float* combined_topk_weights,
auto warp_role = role_meta.first; auto warp_role = role_meta.first;
auto warp_id = role_meta.second; auto warp_id = role_meta.second;
EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + kNumForwarders + 1); EP_DEVICE_ASSERT(num_warps == kNumForwarders + 1);
auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks; auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;
if (warp_role == WarpRole::kNVLSender) { if (warp_role == WarpRole::kNVLSender) {
...@@ -1316,12 +1315,23 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1316,12 +1315,23 @@ combine(int4* combined_x, float* combined_topk_weights,
// NVL layouts // NVL layouts
// NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank]; auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
auto nvl_channel_x = AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); auto nvl_channel_x = AsymBuffer<uint8_t>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_bytes_per_token, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
auto nvl_channel_topk_weights = AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr); auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr);
auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
// TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
auto tma_buffer = smem_tma_buffer + dst_nvl_rank * kNumTMABytesPerWarp;
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + hidden_bytes);
uint32_t tma_phase = 0;
if (lane_id == 0) {
mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared();
fence_barrier_init();
EP_DEVICE_ASSERT(hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
}
__syncwarp();
// Get tasks for each RDMA lane // Get tasks for each RDMA lane
int token_start_idx = 0, token_end_idx = 0; int token_start_idx = 0, token_end_idx = 0;
if (lane_id < kNumRDMARanks) { if (lane_id < kNumRDMARanks) {
...@@ -1336,6 +1346,7 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1336,6 +1346,7 @@ combine(int4* combined_x, float* combined_topk_weights,
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
// Iterate over all tokens and send by chunks // Iterate over all tokens and send by chunks
int current_rdma_idx = channel_id % kNumRDMARanks;
while (true) { while (true) {
// Exit if possible // Exit if possible
if (__all_sync(0xffffffff, token_start_idx >= token_end_idx)) if (__all_sync(0xffffffff, token_start_idx >= token_end_idx))
...@@ -1364,7 +1375,8 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1364,7 +1375,8 @@ combine(int4* combined_x, float* combined_topk_weights,
} }
// Sync token start index and count // Sync token start index and count
for (int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++ current_rdma_idx) { for (int i = 0; i < kNumRDMARanks; ++ i) {
current_rdma_idx = (current_rdma_idx + 1) % kNumRDMARanks;
if (__shfl_sync(0xffffffff, (token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx)) if (__shfl_sync(0xffffffff, (token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))
continue; continue;
...@@ -1382,23 +1394,36 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1382,23 +1394,36 @@ combine(int4* combined_x, float* combined_topk_weights,
} }
dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, current_rdma_idx); dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, current_rdma_idx);
// Copy data // Load data
auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4; auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token;
auto shifted_x = x + token_idx * hidden_int4; auto shifted_x = x + token_idx * hidden_int4;
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global); if (lane_id == 0) {
tma_store_wait();
tma_load_1d(tma_buffer, shifted_x, tma_mbarrier, hidden_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier, hidden_bytes);
}
__syncwarp();
mbarrier_wait(tma_mbarrier, tma_phase);
// Copy source meta // Load source meta
if (lane_id == 0) if (lane_id == num_topk)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx)); *reinterpret_cast<SourceMeta*>(tma_buffer + hidden_bytes) = ld_nc_global(src_meta + token_idx);
// Copy `topk_weights` // Load `topk_weights`
if (lane_id < num_topk) if (lane_id < num_topk)
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, ld_nc_global(topk_weights + token_idx * num_topk + lane_id)); *reinterpret_cast<float*>(tma_buffer + hidden_bytes + sizeof(SourceMeta) + lane_id * sizeof(float)) = ld_nc_global(topk_weights + token_idx * num_topk + lane_id);
// Issue TMA store
tma_store_fence();
__syncwarp();
if (lane_id == 0)
tma_store_1d(tma_buffer, shifted_x_buffers, num_bytes_per_token, false);
} }
lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0; lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
} }
// Move queue tail // Move queue tail
tma_store_wait();
__syncwarp(); __syncwarp();
if (lane_id < kNumRDMARanks and is_lane_ready) if (lane_id < kNumRDMARanks and is_lane_ready)
st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx); st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
...@@ -1406,9 +1431,7 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1406,9 +1431,7 @@ combine(int4* combined_x, float* combined_topk_weights,
} else { } else {
// Combiners and coordinators // Combiners and coordinators
// RDMA symmetric layout // RDMA symmetric layout
auto hidden_bytes = hidden_int4 * sizeof(int4); auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels);
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
...@@ -1418,9 +1441,7 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1418,9 +1441,7 @@ combine(int4* combined_x, float* combined_topk_weights,
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
nvl_buffers[i] = buffer_ptrs[i]; nvl_buffers[i] = buffer_ptrs[i];
auto nvl_channel_x = AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers); auto nvl_channel_x = AsymBuffer<uint8_t>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_bytes_per_token, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_topk_weights = AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_head = AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer); auto nvl_channel_head = AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer);
auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers); auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
...@@ -1448,9 +1469,7 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1448,9 +1469,7 @@ combine(int4* combined_x, float* combined_topk_weights,
EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough"); EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough");
// Advance to the corresponding NVL buffer // Advance to the corresponding NVL buffer
nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4); nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_bytes_per_token);
nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
nvl_channel_head.advance(dst_rdma_rank); nvl_channel_head.advance(dst_rdma_rank);
nvl_channel_tail.advance(dst_rdma_rank); nvl_channel_tail.advance(dst_rdma_rank);
...@@ -1513,9 +1532,9 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1513,9 +1532,9 @@ combine(int4* combined_x, float* combined_topk_weights,
// Combine current token // Combine current token
auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token; void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_token;
auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); }; auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<int4*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx); };
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); }; auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<float*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx); };
combine_token<NUM_MAX_NVL_PEERS, false, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0, combine_token<NUM_MAX_NVL_PEERS, false, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
expected_head, lane_id, expected_head, lane_id,
hidden_int4, num_topk, hidden_int4, num_topk,
...@@ -1533,9 +1552,9 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1533,9 +1552,9 @@ combine(int4* combined_x, float* combined_topk_weights,
if (sub_warp_id == kNumWarpsPerForwarder - 1) { if (sub_warp_id == kNumWarpsPerForwarder - 1) {
if (dst_rdma_rank != rdma_rank) { if (dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token; const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_token;
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token); const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token); const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_token);
nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg, nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
} else { } else {
...@@ -1593,8 +1612,8 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1593,8 +1612,8 @@ combine(int4* combined_x, float* combined_topk_weights,
__syncwarp(); __syncwarp();
// Combine current token // Combine current token
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);}; auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx);};
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);}; auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
combine_token<kNumRDMARanks, true, dtype_t, kNumTopkRDMARanks>(expected_head >= 0, combine_token<kNumRDMARanks, true, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
expected_head, lane_id, expected_head, lane_id,
hidden_int4, num_topk, hidden_int4, num_topk,
...@@ -1612,7 +1631,7 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1612,7 +1631,7 @@ combine(int4* combined_x, float* combined_topk_weights,
} else { } else {
// Coordinator // Coordinator
// Sync shared memory status // Sync shared memory status
is_rdma_receiver_sm ? sync_rdma_receiver_smem() : sync_forwarder_smem(); is_forwarder_sm ? sync_forwarder_smem() : sync_rdma_receiver_smem();
const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks; const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;
int last_rdma_head = 0; int last_rdma_head = 0;
...@@ -1622,13 +1641,13 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1622,13 +1641,13 @@ combine(int4* combined_x, float* combined_topk_weights,
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= 32, "Invalid number of forwarder warps"); EP_STATIC_ASSERT(kNumCombineForwarderWarps <= 32, "Invalid number of forwarder warps");
while (true) { while (true) {
// Retired // Retired
if (is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id])) if (not is_forwarder_sm and __all_sync(0xffffffff, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
break; break;
if (not is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumForwarders or forwarder_retired[lane_id])) if (is_forwarder_sm and __all_sync(0xffffffff, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
break; break;
// Find minimum head for RDMA ranks // Find minimum head for RDMA ranks
if (is_rdma_receiver_sm) { if (not is_forwarder_sm) {
int min_head = std::numeric_limits<int>::max(); int min_head = std::numeric_limits<int>::max();
#pragma unroll #pragma unroll
for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i]) for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
...@@ -1670,10 +1689,13 @@ void combine(cudaDataType_t type, ...@@ -1670,10 +1689,13 @@ void combine(cudaDataType_t type,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode) { int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode) {
constexpr int kNumCombineForwarderWarps = 16; constexpr int kNumCombineForwarderWarps = 16;
constexpr int kNumTMABytesPerWarp = 16384;
constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS;
#define COMBINE_LAUNCH_CASE(num_rdma_ranks) { \ #define COMBINE_LAUNCH_CASE(num_rdma_ranks) { \
auto combine_func = low_latency_mode ? \ auto combine_func = low_latency_mode ? \
combine<true, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps> : combine<false, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps>; \ combine<true, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps, kNumTMABytesPerWarp> : combine<false, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps, kNumTMABytesPerWarp>; \
SET_SHARED_MEMORY_FOR_TMA(combine_func); \
LAUNCH_KERNEL(&cfg, combine_func, \ LAUNCH_KERNEL(&cfg, combine_func, \
reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \ reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
reinterpret_cast<const int4*>(x), topk_weights, \ reinterpret_cast<const int4*>(x), topk_weights, \
...@@ -1688,12 +1710,12 @@ void combine(cudaDataType_t type, ...@@ -1688,12 +1710,12 @@ void combine(cudaDataType_t type,
int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1); auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder; int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder;
EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0); EP_HOST_ASSERT(num_forwarder_warps > NUM_MAX_NVL_PEERS and num_forwarder_warps % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens)); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
EP_HOST_ASSERT(type == CUDA_R_16BF); EP_HOST_ASSERT(type == CUDA_R_16BF);
SETUP_LAUNCH_CONFIG(num_channels * 2, (NUM_MAX_NVL_PEERS + num_forwarder_warps + 1) * 32, stream); SETUP_LAUNCH_CONFIG(num_channels * 2, (num_forwarder_warps + 1) * 32, stream);
SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE); SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE #undef COMBINE_LAUNCH_CASE
} }
......
...@@ -231,9 +231,9 @@ class Buffer: ...@@ -231,9 +231,9 @@ class Buffer:
2: Config(Buffer.num_sms, 10, 256, 6, 128), 2: Config(Buffer.num_sms, 10, 256, 6, 128),
4: Config(Buffer.num_sms, 9, 256, 6, 128), 4: Config(Buffer.num_sms, 9, 256, 6, 128),
8: Config(Buffer.num_sms, 4, 256, 6, 128), 8: Config(Buffer.num_sms, 4, 256, 6, 128),
16: Config(Buffer.num_sms, 2, 288, 28, 128), 16: Config(Buffer.num_sms, 4, 288, 16, 128),
24: Config(Buffer.num_sms, 1, 288, 20, 128), 24: Config(Buffer.num_sms, 1, 288, 8, 128),
32: Config(Buffer.num_sms, 1, 288, 20, 128), 32: Config(Buffer.num_sms, 1, 288, 8, 128),
64: Config(Buffer.num_sms, 1, 288, 20, 128), 64: Config(Buffer.num_sms, 1, 288, 20, 128),
128: Config(Buffer.num_sms, 1, 560, 12, 128), 128: Config(Buffer.num_sms, 1, 560, 12, 128),
144: Config(Buffer.num_sms, 2, 720, 8, 128), 144: Config(Buffer.num_sms, 2, 720, 8, 128),
......
...@@ -209,7 +209,7 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -209,7 +209,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
# Tune combine performance # Tune combine performance
best_time, best_results = 1e10, None best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 5, 1): for nvl_chunk_size in range(1, 13, 1):
for rdma_chunk_size in range(8, 33, 4): for rdma_chunk_size in range(8, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config} tune_args = {'x': recv_x, 'handle': handle, 'config': config}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment