Unverified Commit f9c06bb0 authored by Shangyan Zhou's avatar Shangyan Zhou Committed by GitHub
Browse files

Use TMA to optimize combine forwarder. (#320)



* Remove an outdated todo

* Increase the number of combine forward warps.

* forwarder use TMA.

* Small fix

* Code lint

---------
Co-authored-by: default avatarChenggang Zhao <chenggangz@deepseek.com>
parent e6012370
#include <functional>
#include <optional>
#include "configs.cuh" #include "configs.cuh"
#include "buffer.cuh" #include "buffer.cuh"
#include "exception.cuh" #include "exception.cuh"
...@@ -408,7 +411,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -408,7 +411,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
auto hidden_bytes = hidden_int4 * sizeof(int4); auto hidden_bytes = hidden_int4 * sizeof(int4);
auto scale_bytes = num_scales * sizeof(float); auto scale_bytes = num_scales * sizeof(float);
// TODO: rename `num_bytes_per_rdma_token` after combine refactor
auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, num_scales, num_topk, num_topk); auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, num_scales, num_topk, num_topk);
auto rdma_channel_data = SymBuffer<uint8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_data = SymBuffer<uint8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
...@@ -1219,12 +1221,13 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -1219,12 +1221,13 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
is_cached_dispatch, cpu_rdma_team); is_cached_dispatch, cpu_rdma_team);
} }
template <int kNumRanks, bool kMaybeWithBias, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn> template <int kNumRanks, bool kMaybeWithBias, typename dtype_t, int kMaxNumRanks, bool kUseTMA, int kNumStages, int kNumTMALoadBytes = 0, typename GetAddrFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx, __device__ int combine_token(bool is_token_in_rank, int head_idx,
int lane_id, int hidden_int4, int num_topk, int lane_id, int hidden_int4, int num_topk,
int4* combined_row, float* combined_topk_weights, int4* combined_row, float* combined_topk_weights,
const int4* bias_0_int4, const int4* bias_1_int4, const int4* bias_0_int4, const int4* bias_1_int4,
int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) { int num_max_recv_tokens, const GetAddrFn& get_addr_fn, const ReceiveTWFn& recv_tw_fn,
uint8_t* smem_ptr, uint32_t (&tma_phase)[kNumStages]) {
constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
// Broadcast current heads // Broadcast current heads
...@@ -1237,52 +1240,107 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx, ...@@ -1237,52 +1240,107 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
topk_ranks[num_topk_ranks ++] = i; topk_ranks[num_topk_ranks ++] = i;
} }
EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks); EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks);
EP_STATIC_ASSERT(not (kUseTMA and kMaybeWithBias), "TMA cannot be used by receiver warps");
EP_STATIC_ASSERT(kNumStages == 2, "Only support 2 stages now");
// Reduce data // Reduce data
#pragma unroll if constexpr (kUseTMA) {
for (int i = lane_id; i < hidden_int4; i += 32) { constexpr int kNumTMABufferBytesPerStage = kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1) + 16;
// Read bias EP_DEVICE_ASSERT(hidden_int4 % 32 == 0);
// TODO: make it as a finer-grained template
int4 bias_0_value_int4, bias_1_value_int4;
if (kMaybeWithBias) {
bias_0_value_int4 = bias_0_int4 != nullptr ? ld_nc_global(bias_0_int4 + i) : make_int4(0, 0, 0, 0);
bias_1_value_int4 = bias_1_int4 != nullptr ? ld_nc_global(bias_1_int4 + i) : make_int4(0, 0, 0, 0);
}
// Read buffers auto tma_load_buffer = [=](const int& i, const int& j) -> int4* { return reinterpret_cast<int4*>(smem_ptr + i * kNumTMABufferBytesPerStage + j * kNumTMALoadBytes); };
// TODO: maybe too many registers here auto tma_store_buffer = [=](const int& i) -> int4* { return reinterpret_cast<int4*>(smem_ptr + i * kNumTMABufferBytesPerStage + NUM_MAX_NVL_PEERS * kNumTMALoadBytes); };
int4 recv_value_int4[kMaxNumRanks]; auto tma_mbarrier = [=](const int& i) -> uint64_t* { return reinterpret_cast<uint64_t*>(smem_ptr + i * kNumTMABufferBytesPerStage + (NUM_MAX_NVL_PEERS + 1) * kNumTMALoadBytes); };
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);
// Clean // Prefetch
// Reduce bias if (lane_id < num_topk_ranks)
float values[kDtypePerInt4] = {0}; tma_load_1d(tma_load_buffer(0, lane_id), get_addr_fn(topk_ranks[lane_id], slot_indices[lane_id], 0), tma_mbarrier(0), kNumTMALoadBytes);
if (kMaybeWithBias) { mbarrier_arrive_and_expect_tx(tma_mbarrier(0), lane_id < num_topk_ranks ? kNumTMALoadBytes : 0);
auto bias_0_values = reinterpret_cast<const dtype_t*>(&bias_0_value_int4); __syncwarp();
auto bias_1_values = reinterpret_cast<const dtype_t*>(&bias_1_value_int4);
for (int shifted = 0, iter = 0; shifted < hidden_int4; shifted += 32, iter += 1) {
const int stage_idx = iter % kNumStages;
const int next_stage_idx = (iter + 1) % kNumStages;
// Prefetch next stage
if (shifted + 32 < hidden_int4) {
if (lane_id < num_topk_ranks)
tma_load_1d(tma_load_buffer(next_stage_idx, lane_id), get_addr_fn(topk_ranks[lane_id], slot_indices[lane_id], shifted + 32), tma_mbarrier(next_stage_idx), kNumTMALoadBytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier(next_stage_idx), lane_id < num_topk_ranks ? kNumTMALoadBytes : 0);
__syncwarp();
}
mbarrier_wait(tma_mbarrier(stage_idx), tma_phase[stage_idx]);
float values[kDtypePerInt4] = {0};
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j) {
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(tma_load_buffer(stage_idx, j) + lane_id);
#pragma unroll
for (int k = 0; k < kDtypePerInt4; ++ k)
values[k] += static_cast<float>(recv_value_dtypes[k]);
}
tma_store_wait<kNumStages - 1>();
auto out_dtypes = reinterpret_cast<dtype_t*>(tma_store_buffer(stage_idx) + lane_id);
#pragma unroll #pragma unroll
for (int j = 0; j < kDtypePerInt4; ++ j) for (int j = 0; j < kDtypePerInt4; ++ j)
values[j] = static_cast<float>(bias_0_values[j]) + static_cast<float>(bias_1_values[j]); out_dtypes[j] = static_cast<dtype_t>(values[j]);
tma_store_fence();
__syncwarp();
if (lane_id == 0)
tma_store_1d(tma_store_buffer(stage_idx), combined_row + shifted + lane_id, kNumTMALoadBytes);
__syncwarp();
} }
// Reduce all-to-all results // Flush all writes
tma_store_wait();
} else {
#pragma unroll #pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j) { for (int i = lane_id; i < hidden_int4; i += 32) {
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]); // Read bias
// TODO: make it as a finer-grained template
int4 bias_0_value_int4, bias_1_value_int4;
if constexpr (kMaybeWithBias) {
bias_0_value_int4 = bias_0_int4 != nullptr ? ld_nc_global(bias_0_int4 + i) : make_int4(0, 0, 0, 0);
bias_1_value_int4 = bias_1_int4 != nullptr ? ld_nc_global(bias_1_int4 + i) : make_int4(0, 0, 0, 0);
}
// Read buffers
// TODO: maybe too many registers here
int4 recv_value_int4[kMaxNumRanks];
#pragma unroll #pragma unroll
for (int k = 0; k < kDtypePerInt4; ++ k) for (int j = 0; j < num_topk_ranks; ++ j)
values[k] += static_cast<float>(recv_value_dtypes[k]); recv_value_int4[j] = ld_nc_global(get_addr_fn(topk_ranks[j], slot_indices[j], i));
}
// Clean
// Reduce bias
float values[kDtypePerInt4] = {0};
if constexpr (kMaybeWithBias) {
auto bias_0_values = reinterpret_cast<const dtype_t*>(&bias_0_value_int4);
auto bias_1_values = reinterpret_cast<const dtype_t*>(&bias_1_value_int4);
#pragma unroll
for (int j = 0; j < kDtypePerInt4; ++ j)
values[j] = static_cast<float>(bias_0_values[j]) + static_cast<float>(bias_1_values[j]);
}
// Cast back to `dtype_t` and write // Reduce all-to-all results
int4 out_int4; #pragma unroll
auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4); for (int j = 0; j < num_topk_ranks; ++ j) {
#pragma unroll auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
for (int j = 0; j < kDtypePerInt4; ++ j) #pragma unroll
out_dtypes[j] = static_cast<dtype_t>(values[j]); for (int k = 0; k < kDtypePerInt4; ++ k)
st_na_global(combined_row + i, out_int4); values[k] += static_cast<float>(recv_value_dtypes[k]);
}
// Cast back to `dtype_t` and write
int4 out_int4;
auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
#pragma unroll
for (int j = 0; j < kDtypePerInt4; ++ j)
out_dtypes[j] = static_cast<dtype_t>(values[j]);
st_na_global(combined_row + i, out_int4);
}
} }
// Reduce `topk_weights` // Reduce `topk_weights`
...@@ -1301,7 +1359,8 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx, ...@@ -1301,7 +1359,8 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
template<bool kLowLatencyMode, template<bool kLowLatencyMode,
int kNumRDMARanks, typename dtype_t, int kNumRDMARanks, typename dtype_t,
int kNumCombineForwarderWarps, int kNumCombineForwarderWarps,
int kNumTMABytesPerWarp, int kNumTMABytesPerSenderWarp,
int kNumTMABytesPerForwarderWarp,
int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks), int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks),
int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1, int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder, int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder,
...@@ -1378,14 +1437,14 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1378,14 +1437,14 @@ combine(int4* combined_x, float* combined_topk_weights,
// TMA stuffs // TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[]; extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
auto tma_buffer = smem_tma_buffer + dst_nvl_rank * kNumTMABytesPerWarp; auto tma_buffer = smem_tma_buffer + dst_nvl_rank * kNumTMABytesPerSenderWarp;
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + hidden_bytes); auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + hidden_bytes);
uint32_t tma_phase = 0; uint32_t tma_phase = 0;
if (lane_id == 0) { if (lane_id == 0) {
mbarrier_init(tma_mbarrier, 1); mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared(); fence_view_async_shared();
fence_barrier_init(); fence_barrier_init();
EP_DEVICE_ASSERT(hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp); EP_DEVICE_ASSERT(hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerSenderWarp);
} }
__syncwarp(); __syncwarp();
...@@ -1525,6 +1584,23 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1525,6 +1584,23 @@ combine(int4* combined_x, float* combined_topk_weights,
}; };
EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough"); EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough");
// TMA stuffs
constexpr int kNumStages = 2;
constexpr int kNumTMALoadBytes = sizeof(int4) * 32;
constexpr int kNumTMABufferBytesPerStage = kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1) + 16;
EP_STATIC_ASSERT(kNumTMABufferBytesPerStage * kNumStages <= kNumTMABytesPerForwarderWarp, "TMA buffer is not larger enough");
extern __shared__ __align__(1024) uint8_t smem_buffer[];
auto smem_ptr = smem_buffer + warp_id * kNumStages * kNumTMABufferBytesPerStage;
auto tma_mbarrier = [=](const int& i) { return reinterpret_cast<uint64_t*>(smem_ptr + i * kNumTMABufferBytesPerStage + kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1)); };
uint32_t tma_phase[kNumStages] = {0};
if (lane_id < kNumStages) {
mbarrier_init(tma_mbarrier(lane_id), 32);
fence_view_async_shared();
fence_barrier_init();
}
__syncwarp();
// Advance to the corresponding NVL buffer // Advance to the corresponding NVL buffer
nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_bytes_per_token); nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_bytes_per_token);
nvl_channel_head.advance(dst_rdma_rank); nvl_channel_head.advance(dst_rdma_rank);
...@@ -1590,14 +1666,15 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1590,14 +1666,15 @@ combine(int4* combined_x, float* combined_topk_weights,
// Combine current token // Combine current token
auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_token; void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_token;
auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<int4*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx); }; auto get_addr_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast<int4*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx; };
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<float*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx); }; auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<float*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx); };
combine_token<NUM_MAX_NVL_PEERS, false, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0, combine_token<NUM_MAX_NVL_PEERS, false, dtype_t, NUM_MAX_NVL_PEERS, true, kNumStages, kNumTMALoadBytes>(expected_head >= 0,
expected_head, lane_id, expected_head, lane_id,
hidden_int4, num_topk, hidden_int4, num_topk,
static_cast<int4*>(shifted), static_cast<int4*>(shifted),
reinterpret_cast<float*>(static_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)), reinterpret_cast<float*>(static_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
nullptr, nullptr, num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn); nullptr, nullptr, num_max_nvl_chunked_recv_tokens_per_rdma, get_addr_fn, recv_tw_fn,
smem_ptr, tma_phase);
// Update head // Update head
if (lane_id < NUM_MAX_NVL_PEERS) if (lane_id < NUM_MAX_NVL_PEERS)
...@@ -1669,16 +1746,18 @@ combine(int4* combined_x, float* combined_topk_weights, ...@@ -1669,16 +1746,18 @@ combine(int4* combined_x, float* combined_topk_weights,
__syncwarp(); __syncwarp();
// Combine current token // Combine current token
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx);}; auto get_addr_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast<int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx;};
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);}; auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
combine_token<kNumRDMARanks, true, dtype_t, kNumTopkRDMARanks>(expected_head >= 0, uint32_t dummy_tma_phases[2];
expected_head, lane_id, combine_token<kNumRDMARanks, true, dtype_t, kNumTopkRDMARanks, false, 2>(expected_head >= 0,
hidden_int4, num_topk, expected_head, lane_id,
combined_x + token_idx * hidden_int4, hidden_int4, num_topk,
combined_topk_weights + token_idx * num_topk, combined_x + token_idx * hidden_int4,
bias_0 == nullptr ? nullptr : bias_0 + token_idx * hidden_int4, combined_topk_weights + token_idx * num_topk,
bias_1 == nullptr ? nullptr : bias_1 + token_idx * hidden_int4, bias_0 == nullptr ? nullptr : bias_0 + token_idx * hidden_int4,
num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn); bias_1 == nullptr ? nullptr : bias_1 + token_idx * hidden_int4,
num_max_rdma_chunked_recv_tokens, get_addr_fn, recv_tw_fn,
nullptr, dummy_tma_phases);
} }
// Retired // Retired
...@@ -1745,13 +1824,14 @@ void combine(cudaDataType_t type, ...@@ -1745,13 +1824,14 @@ void combine(cudaDataType_t type,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode) { int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode) {
constexpr int kNumCombineForwarderWarps = 16; constexpr int kNumCombineForwarderWarps = 24;
constexpr int kNumTMABytesPerWarp = 16384; constexpr int kNumTMABytesPerSenderWarp = 16384;
constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS; constexpr int kNumTMABytesPerForwarderWarp = 9248;
constexpr int smem_size = std::max(kNumTMABytesPerSenderWarp * NUM_MAX_NVL_PEERS, kNumTMABytesPerForwarderWarp * kNumCombineForwarderWarps);
#define COMBINE_LAUNCH_CASE(num_rdma_ranks) { \ #define COMBINE_LAUNCH_CASE(num_rdma_ranks) { \
auto combine_func = low_latency_mode ? \ auto combine_func = low_latency_mode ? \
combine<true, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps, kNumTMABytesPerWarp> : combine<false, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps, kNumTMABytesPerWarp>; \ combine<true, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps, kNumTMABytesPerSenderWarp, kNumTMABytesPerForwarderWarp> : combine<false, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps, kNumTMABytesPerSenderWarp, kNumTMABytesPerForwarderWarp>; \
SET_SHARED_MEMORY_FOR_TMA(combine_func); \ SET_SHARED_MEMORY_FOR_TMA(combine_func); \
LAUNCH_KERNEL(&cfg, combine_func, \ LAUNCH_KERNEL(&cfg, combine_func, \
reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \ reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
...@@ -1767,6 +1847,7 @@ void combine(cudaDataType_t type, ...@@ -1767,6 +1847,7 @@ void combine(cudaDataType_t type,
int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1); auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder; int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder;
EP_HOST_ASSERT(num_rdma_ranks <= kNumCombineForwarderWarps);
EP_HOST_ASSERT(num_forwarder_warps > NUM_MAX_NVL_PEERS and num_forwarder_warps % num_rdma_ranks == 0); EP_HOST_ASSERT(num_forwarder_warps > NUM_MAX_NVL_PEERS and num_forwarder_warps % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens)); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment