"vscode:/vscode.git/clone" did not exist on "59484a6fb482160c54d6d89d7324dc66c1d6fc79"
Unverified Commit 06f417dc authored by Shangyan Zhou's avatar Shangyan Zhou Committed by GitHub
Browse files

Use TMA to optimize internode combine. (#287)



* Let forwarders use a dedicated SM

* Shuffle rdma idx

* Sender use TMA.

* Adjust the tuning chunk size.

* Modify NVL chunk layout.

* Update some combine config.

* Small lint

* Minor fix

* Overlap TMA store

---------
Co-authored-by: default avatarChenggang Zhao <chenggangz@deepseek.com>
parent 1cf85fb2
......@@ -39,7 +39,7 @@ int get_source_meta_bytes() {
}
__host__ __device__ __forceinline__
int get_num_bytes_per_rdma_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) {
int get_num_bytes_per_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) {
return static_cast<int>(align(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4)));
}
......@@ -49,7 +49,7 @@ std::pair<int, int> get_rdma_clean_meta(int hidden_int4, int num_scales, int num
int num_channels) {
// Return `int32_t` offset and count to clean
return {
(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_channels) / sizeof(int),
(get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_channels) / sizeof(int),
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels
};
}
......@@ -59,13 +59,10 @@ std::pair<int, int> get_nvl_clean_meta(int hidden_int4, int num_scales, int num_
int num_topk_weights, int num_rdma_ranks, int num_nvl_ranks,
int num_nvl_recv_buffer_tokens, int num_channels, bool is_dispatch) {
// Return `int32_t` offset and to clean
// TODO: remove `is_dispatch` after finishing combine refactor
EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");
const int num_bytes_per_token = is_dispatch ?
get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) :
(hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float) + sizeof(SourceMeta));
return {
(num_nvl_recv_buffer_tokens * num_bytes_per_token * num_nvl_ranks * num_channels) / sizeof(int),
(num_nvl_recv_buffer_tokens * get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_nvl_ranks * num_channels) / sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels,
};
}
......@@ -401,8 +398,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
auto hidden_bytes = hidden_int4 * sizeof(int4);
auto scale_bytes = num_scales * sizeof(float);
// TODO: rename `num_bytes_per_rdma_token` after combine refactor
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
auto rdma_channel_data = SymBuffer<uint8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, num_scales, num_topk, num_topk);
auto rdma_channel_data = SymBuffer<uint8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
......@@ -417,7 +414,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
// Allocate buffers
auto nvl_channel_x = AsymBuffer<uint8_t>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_bytes_per_rdma_token, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_x = AsymBuffer<uint8_t>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_bytes_per_token, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
......@@ -440,7 +437,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared();
fence_barrier_init();
EP_DEVICE_ASSERT(num_bytes_per_rdma_token + sizeof(uint64_t) <= kNumTMABytesPerWarp);
EP_DEVICE_ASSERT(num_bytes_per_token + sizeof(uint64_t) <= kNumTMABytesPerWarp);
}
__syncwarp();
......@@ -528,7 +525,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);
if (lane_id == num_topk_ranks)
src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
dst_send_buffers[num_topk_ranks ++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token;
dst_send_buffers[num_topk_ranks ++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_token;
}
EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);
......@@ -656,9 +653,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue;
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
const size_t num_bytes_per_msg = num_bytes_per_token * num_tokens_to_issue;
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_token);
nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
} else {
......@@ -771,7 +768,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Iterate over every token from the RDMA buffer
for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) {
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
auto shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
auto shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_token;
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(shifted + hidden_bytes + scale_bytes));
lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;
bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
......@@ -786,17 +783,17 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Get an empty slot
int dst_slot_idx = (cached_nvl_channel_tail ++) % num_max_nvl_chunked_recv_tokens;
auto dst_shifted = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_rdma_token;
auto dst_shifted = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token;
// Copy data
if (lane_id == 0) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_rdma_token, false);
mbarrier_arrive_and_expect_tx(tma_mbarrier, num_bytes_per_rdma_token);
tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_token, false);
mbarrier_arrive_and_expect_tx(tma_mbarrier, num_bytes_per_token);
}
__syncwarp();
mbarrier_wait(tma_mbarrier, tma_phase);
if (lane_id == 0)
tma_store_1d(tma_buffer, dst_shifted, num_bytes_per_rdma_token);
tma_store_1d(tma_buffer, dst_shifted, num_bytes_per_token);
__syncwarp();
// In case of insufficient NVL buffers, early stopping
......@@ -918,7 +915,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx, -- num_tokens_to_recv) {
int token_idx_in_buffer = (cached_channel_head_idx ++) % num_max_nvl_chunked_recv_tokens;
auto shifted = nvl_channel_x.buffer() + token_idx_in_buffer * num_bytes_per_rdma_token;
auto shifted = nvl_channel_x.buffer() + token_idx_in_buffer * num_bytes_per_token;
auto meta = ld_nc_global(reinterpret_cast<SourceMeta*>(shifted + hidden_bytes + scale_bytes));
int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank);
(lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;
......@@ -1247,11 +1244,12 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
template<bool kLowLatencyMode,
int kNumRDMARanks, typename dtype_t,
int kNumCombineForwarderWarps,
int kNumTMABytesPerWarp,
int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks),
int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder,
int kNumRDMAReceivers = kNumForwarders + NUM_MAX_NVL_PEERS>
__global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1)
int kNumRDMAReceivers = kNumForwarders - NUM_MAX_NVL_PEERS>
__global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1)
combine(int4* combined_x, float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const int4* x, const float* topk_weights,
......@@ -1273,31 +1271,32 @@ combine(int4* combined_x, float* combined_topk_weights,
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2;
const bool is_rdma_receiver_sm = sm_id % 2 == 1;
const bool is_forwarder_sm = sm_id % 2 == 1;
EP_DEVICE_ASSERT(num_topk <= 32);
EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0);
const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));
const auto hidden_bytes = hidden_int4 * sizeof(int4);
const auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, 0, 0, num_topk);
// NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
auto role_meta = [=]() -> std::pair<WarpRole, int> {
auto warp_id = thread_id / 32;
if (not is_rdma_receiver_sm) {
if (not is_forwarder_sm) {
if (warp_id < NUM_MAX_NVL_PEERS) {
auto shuffled_warp_id = warp_id;
shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS;
return {WarpRole::kNVLSender, shuffled_warp_id};
} else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) {
auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS;
shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders;
return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id};
} else if (warp_id < kNumForwarders) {
return {WarpRole::kRDMAReceiver, warp_id - NUM_MAX_NVL_PEERS};
} else {
return {WarpRole::kCoordinator, 0};
}
} else {
if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) {
return {WarpRole::kRDMAReceiver, warp_id};
if (warp_id < kNumForwarders) {
auto shuffled_warp_id = (warp_id + channel_id) % kNumForwarders;
return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id};
} else {
return {WarpRole::kCoordinator, 0};
}
......@@ -1306,7 +1305,7 @@ combine(int4* combined_x, float* combined_topk_weights,
auto warp_role = role_meta.first;
auto warp_id = role_meta.second;
EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + kNumForwarders + 1);
EP_DEVICE_ASSERT(num_warps == kNumForwarders + 1);
auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;
if (warp_role == WarpRole::kNVLSender) {
......@@ -1316,12 +1315,23 @@ combine(int4* combined_x, float* combined_topk_weights,
// NVL layouts
// NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
auto nvl_channel_x = AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
auto nvl_channel_topk_weights = AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
auto nvl_channel_x = AsymBuffer<uint8_t>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_bytes_per_token, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr);
auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
// TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
auto tma_buffer = smem_tma_buffer + dst_nvl_rank * kNumTMABytesPerWarp;
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + hidden_bytes);
uint32_t tma_phase = 0;
if (lane_id == 0) {
mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared();
fence_barrier_init();
EP_DEVICE_ASSERT(hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
}
__syncwarp();
// Get tasks for each RDMA lane
int token_start_idx = 0, token_end_idx = 0;
if (lane_id < kNumRDMARanks) {
......@@ -1336,6 +1346,7 @@ combine(int4* combined_x, float* combined_topk_weights,
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
// Iterate over all tokens and send by chunks
int current_rdma_idx = channel_id % kNumRDMARanks;
while (true) {
// Exit if possible
if (__all_sync(0xffffffff, token_start_idx >= token_end_idx))
......@@ -1364,7 +1375,8 @@ combine(int4* combined_x, float* combined_topk_weights,
}
// Sync token start index and count
for (int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++ current_rdma_idx) {
for (int i = 0; i < kNumRDMARanks; ++ i) {
current_rdma_idx = (current_rdma_idx + 1) % kNumRDMARanks;
if (__shfl_sync(0xffffffff, (token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))
continue;
......@@ -1382,23 +1394,36 @@ combine(int4* combined_x, float* combined_topk_weights,
}
dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, current_rdma_idx);
// Copy data
auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
// Load data
auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token;
auto shifted_x = x + token_idx * hidden_int4;
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
if (lane_id == 0) {
tma_store_wait();
tma_load_1d(tma_buffer, shifted_x, tma_mbarrier, hidden_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier, hidden_bytes);
}
__syncwarp();
mbarrier_wait(tma_mbarrier, tma_phase);
// Copy source meta
if (lane_id == 0)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx));
// Load source meta
if (lane_id == num_topk)
*reinterpret_cast<SourceMeta*>(tma_buffer + hidden_bytes) = ld_nc_global(src_meta + token_idx);
// Copy `topk_weights`
// Load `topk_weights`
if (lane_id < num_topk)
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
*reinterpret_cast<float*>(tma_buffer + hidden_bytes + sizeof(SourceMeta) + lane_id * sizeof(float)) = ld_nc_global(topk_weights + token_idx * num_topk + lane_id);
// Issue TMA store
tma_store_fence();
__syncwarp();
if (lane_id == 0)
tma_store_1d(tma_buffer, shifted_x_buffers, num_bytes_per_token, false);
}
lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
}
// Move queue tail
tma_store_wait();
__syncwarp();
if (lane_id < kNumRDMARanks and is_lane_ready)
st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
......@@ -1406,9 +1431,7 @@ combine(int4* combined_x, float* combined_topk_weights,
} else {
// Combiners and coordinators
// RDMA symmetric layout
auto hidden_bytes = hidden_int4 * sizeof(int4);
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
......@@ -1418,9 +1441,7 @@ combine(int4* combined_x, float* combined_topk_weights,
#pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
nvl_buffers[i] = buffer_ptrs[i];
auto nvl_channel_x = AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_topk_weights = AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_x = AsymBuffer<uint8_t>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_bytes_per_token, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_head = AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer);
auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
......@@ -1448,9 +1469,7 @@ combine(int4* combined_x, float* combined_topk_weights,
EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough");
// Advance to the corresponding NVL buffer
nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4);
nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_bytes_per_token);
nvl_channel_head.advance(dst_rdma_rank);
nvl_channel_tail.advance(dst_rdma_rank);
......@@ -1513,9 +1532,9 @@ combine(int4* combined_x, float* combined_topk_weights,
// Combine current token
auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_token;
auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<int4*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx); };
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<float*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx); };
combine_token<NUM_MAX_NVL_PEERS, false, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
......@@ -1533,9 +1552,9 @@ combine(int4* combined_x, float* combined_topk_weights,
if (sub_warp_id == kNumWarpsPerForwarder - 1) {
if (dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_token;
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_token);
nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
} else {
......@@ -1593,8 +1612,8 @@ combine(int4* combined_x, float* combined_topk_weights,
__syncwarp();
// Combine current token
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx);};
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
combine_token<kNumRDMARanks, true, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
......@@ -1612,7 +1631,7 @@ combine(int4* combined_x, float* combined_topk_weights,
} else {
// Coordinator
// Sync shared memory status
is_rdma_receiver_sm ? sync_rdma_receiver_smem() : sync_forwarder_smem();
is_forwarder_sm ? sync_forwarder_smem() : sync_rdma_receiver_smem();
const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;
int last_rdma_head = 0;
......@@ -1622,13 +1641,13 @@ combine(int4* combined_x, float* combined_topk_weights,
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= 32, "Invalid number of forwarder warps");
while (true) {
// Retired
if (is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
if (not is_forwarder_sm and __all_sync(0xffffffff, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
break;
if (not is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
if (is_forwarder_sm and __all_sync(0xffffffff, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
break;
// Find minimum head for RDMA ranks
if (is_rdma_receiver_sm) {
if (not is_forwarder_sm) {
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
......@@ -1670,10 +1689,13 @@ void combine(cudaDataType_t type,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode) {
constexpr int kNumCombineForwarderWarps = 16;
constexpr int kNumTMABytesPerWarp = 16384;
constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS;
#define COMBINE_LAUNCH_CASE(num_rdma_ranks) { \
auto combine_func = low_latency_mode ? \
combine<true, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps> : combine<false, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps>; \
combine<true, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps, kNumTMABytesPerWarp> : combine<false, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps, kNumTMABytesPerWarp>; \
SET_SHARED_MEMORY_FOR_TMA(combine_func); \
LAUNCH_KERNEL(&cfg, combine_func, \
reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
reinterpret_cast<const int4*>(x), topk_weights, \
......@@ -1688,12 +1710,12 @@ void combine(cudaDataType_t type,
int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder;
EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_forwarder_warps > NUM_MAX_NVL_PEERS and num_forwarder_warps % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
EP_HOST_ASSERT(type == CUDA_R_16BF);
SETUP_LAUNCH_CONFIG(num_channels * 2, (NUM_MAX_NVL_PEERS + num_forwarder_warps + 1) * 32, stream);
SETUP_LAUNCH_CONFIG(num_channels * 2, (num_forwarder_warps + 1) * 32, stream);
SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}
......
......@@ -231,9 +231,9 @@ class Buffer:
2: Config(Buffer.num_sms, 10, 256, 6, 128),
4: Config(Buffer.num_sms, 9, 256, 6, 128),
8: Config(Buffer.num_sms, 4, 256, 6, 128),
16: Config(Buffer.num_sms, 2, 288, 28, 128),
24: Config(Buffer.num_sms, 1, 288, 20, 128),
32: Config(Buffer.num_sms, 1, 288, 20, 128),
16: Config(Buffer.num_sms, 4, 288, 16, 128),
24: Config(Buffer.num_sms, 1, 288, 8, 128),
32: Config(Buffer.num_sms, 1, 288, 8, 128),
64: Config(Buffer.num_sms, 1, 288, 20, 128),
128: Config(Buffer.num_sms, 1, 560, 12, 128),
144: Config(Buffer.num_sms, 2, 720, 8, 128),
......
......@@ -209,7 +209,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
# Tune combine performance
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 5, 1):
for nvl_chunk_size in range(1, 13, 1):
for rdma_chunk_size in range(8, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment