Commit 243eca85 authored by lishen01
Browse files

fix: 解决高吞吐的SM最大只能到48的问题,提升高吞吐的整体性能

parent 766b17b3
...@@ -47,21 +47,25 @@ struct Config { ...@@ -47,21 +47,25 @@ struct Config {
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % (2 * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL) == 0); EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % (2 * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL) == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms; const int num_channels = num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
size_t num_bytes = 0; // 计算每个nvl通信数据包的数据量
num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int); size_t num_single_nvl_bag_bytes =
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes; hidden_bytes + // 数据缓冲区(Token Data)。存储从 RDMA 转发过来的 token 数据(x 张量)
#ifndef DISABLE_ROCSHMEM #ifndef DISABLE_ROCSHMEM
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes() + // 源元数据缓冲区(Source Metadata)。存储每个 token 的源信息(哪个 RDMA rank 发送的)
internode::get_source_meta_bytes();
#endif #endif
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * kNumMaxTopK * sizeof(int) + // TopK 索引缓冲区。存储每个 token 的 top-k 专家索引
sizeof(int64_t); kNumMaxTopK * sizeof(float) + // TopK 权重缓冲区。存储每个 token 的 top-k 专家权重
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * kNumMaxScales * sizeof(float); // Scale 缓冲区。存储每个 token 的量化缩放因子
sizeof(float);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * // 计算每个 NVL channel 的控制信息所需的字节数,存储每个 NVL channel 的前缀索引信息,用于快速定位数据(nvl_channel_prefix_start、nvl_channel_prefix_end 等)
kNumMaxScales * sizeof(float); size_t num_single_nvl_control_bytes = (2 * num_rdma_ranks + 3) * sizeof(int);
// NVL 数据总的字节数
size_t num_bytes = (num_single_nvl_bag_bytes * num_max_nvl_chunked_recv_tokens + num_single_nvl_control_bytes) * num_channels * num_nvl_ranks;
// 128 字节对齐,匹配 GPU 缓存行大小,优化内存访问。
num_bytes = ((num_bytes + 127) / 128) * 128; num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes; return num_bytes;
} }
...@@ -79,22 +83,25 @@ struct Config { ...@@ -79,22 +83,25 @@ struct Config {
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0); EP_HOST_ASSERT(num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_channels = num_sms; const int num_channels = num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
size_t num_bytes = 0; // 计算每个rdma通信数据包的数据量
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int); size_t num_single_rdma_bag_bytes =
num_bytes += hidden_bytes + // 数据缓冲区。存储实际的 token 数据(x 张量),对应代码中的 rdma_channel_data
num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2; internode::get_source_meta_bytes() + // 源元数据缓冲区。存储每个 token 的源信息(SourceMeta)
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int) + // 存储每个 token 的 top-k 专家索引。对应 topk_idx 数据
internode::get_source_meta_bytes() * 2; kNumMaxTopK * sizeof(float) + // 存储每个 token 的 top-k 专家权重。对应 topk_weights 数据
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) + // 存储每个 token 的缩放因子(x_scales)
kNumMaxTopK * sizeof(int64_t) * 2; sizeof(int4); // 预留空间用于内存对齐和未来扩展
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxTopK * sizeof(float) * 2; // 计算每个 RDMA channel 的控制信息(起始/结束索引)所需的字节数,对应代码中的 rdma_channel_meta
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * size_t num_single_rdma_control_bytes = (NUM_MAX_NVL_PEERS * 2 + 4) * sizeof(int);
kNumMaxScales * sizeof(float) * 2;
num_bytes += // RDMA 数据总的字节数
num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2; size_t num_bytes = (num_single_rdma_bag_bytes * num_max_rdma_chunked_recv_tokens + num_single_rdma_control_bytes) *
num_channels * num_rdma_ranks * 2;
// 128 字节对齐(缓存行对齐),优化内存访问性能
num_bytes = ((num_bytes + 127) / 128) * 128; num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes; return num_bytes;
#else #else
......
...@@ -937,6 +937,7 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te ...@@ -937,6 +937,7 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value(); gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value();
recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value(); recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value();
EP_HOST_ASSERT(num_rdma_bytes >= config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks));
// Just a barrier and clean flags // Just a barrier and clean flags
internode::cached_notify( internode::cached_notify(
hidden_int4, num_scales, num_topk, num_topk, num_ranks, num_channels, 0, nullptr, hidden_int4, num_scales, num_topk, num_topk, num_ranks, num_channels, 0, nullptr,
...@@ -1205,6 +1206,7 @@ Buffer::internode_combine( ...@@ -1205,6 +1206,7 @@ Buffer::internode_combine(
EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <= EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <=
config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks); config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks);
EP_HOST_ASSERT(num_rdma_bytes >= config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks));
// Launch barrier and reset queue head and tail // Launch barrier and reset queue head and tail
internode::cached_notify( internode::cached_notify(
......
...@@ -7,6 +7,10 @@ ...@@ -7,6 +7,10 @@
#ifndef DISABLE_ROCSHMEM #ifndef DISABLE_ROCSHMEM
// 安全检查:确保宏已定义
#ifndef HIP_VERSION_PATCH
#error "HIP_VERSION_PATCH not defined! Check your HIP installation."
#endif
// TODO: fix unroll warnings // TODO: fix unroll warnings
// #ifdef __clang__ // #ifdef __clang__
// #pragma clang diagnostic push // #pragma clang diagnostic push
...@@ -56,16 +60,18 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_ ...@@ -56,16 +60,18 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
__host__ __device__ __forceinline__ std::pair<int, int> __host__ __device__ __forceinline__ std::pair<int, int>
get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) { int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_channels) {
// Return `int32_t` offset and count to clean // Return `int32_t` offset and count to clean
return {(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * return {(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) *
num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / sizeof(int), num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_channels) / sizeof(int),
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms}; (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels};
} }
__host__ __device__ __forceinline__ std::pair<int, int> __host__ __device__ __forceinline__ std::pair<int, int>
get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens, int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens,
int num_sms) { int num_channels) {
// Return `int32_t` offset and to clean // Return `int32_t` offset and to clean
EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0,
"Invalid size of `SourceMeta`"); "Invalid size of `SourceMeta`");
...@@ -73,8 +79,8 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -73,8 +79,8 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
(num_nvl_recv_buffer_tokens * (num_nvl_recv_buffer_tokens *
(hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + (hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
num_topk_weights * sizeof(float) + sizeof(SourceMeta)) * num_topk_weights * sizeof(float) + sizeof(SourceMeta)) *
num_nvl_ranks * num_sms) / sizeof(int), num_nvl_ranks * num_channels) / sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms, num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels,
}; };
} }
...@@ -1230,24 +1236,25 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i ...@@ -1230,24 +1236,25 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if (is_cached_dispatch) if (is_cached_dispatch)
return; return;
EP_DEVICE_ASSERT(num_warps >= num_channels);
EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize); EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize);
// Iterate in reverse order // Iterate in reverse order
if (lane_id < num_rdma_ranks and warp_id < num_channels) { for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
int token_start_idx, token_end_idx; if (lane_id < num_rdma_ranks) {
get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, int token_start_idx, token_end_idx;
token_end_idx); get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx,
token_end_idx);
// NOTES: `1 << 25` is a heuristic large number // NOTES: `1 << 25` is a heuristic large number
int last_head = 1 << 25; int last_head = 1 << 25;
for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) { for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
auto current_head = auto current_head =
__ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id); __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
if (current_head < 0) { if (current_head < 0) {
combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1; combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
} else { } else {
last_head = current_head; last_head = current_head;
}
} }
} }
} }
...@@ -1255,34 +1262,34 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i ...@@ -1255,34 +1262,34 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if (is_cached_dispatch) if (is_cached_dispatch)
return; return;
EP_DEVICE_ASSERT(num_warps >= num_channels); EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and
rdma_rank_prefix_sum != nullptr);
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers"); EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers");
constexpr int num_clean_sms = 2; constexpr int num_clean_sms = 2;
if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) { for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
for (int dst_rdma_rank = sm_id - num_clean_sms; dst_rdma_rank < num_rdma_ranks; if (lane_id < NUM_MAX_NVL_PEERS ) {
dst_rdma_rank += num_channels * 2 - num_clean_sms) { for (int dst_rdma_rank = sm_id - num_clean_sms; dst_rdma_rank < num_rdma_ranks;
// Iterate in reverse order dst_rdma_rank += num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL - num_clean_sms) {
int token_start_idx = // Iterate in reverse order
warp_id == 0 int token_start_idx =
? 0 channel_id == 0
: rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1]; ? 0
int token_end_idx = : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id]; int token_end_idx =
int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
token_start_idx += shift, token_end_idx += shift; int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
token_start_idx += shift, token_end_idx += shift;
// NOTES: `1 << 25` is a heuristic large number
int last_head = 1 << 25; // NOTES: `1 << 25` is a heuristic large number
for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) { int last_head = 1 << 25;
auto current_head = for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
__ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); auto current_head =
if (current_head < 0) { __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1; if (current_head < 0) {
} else { combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
last_head = current_head; } else {
last_head = current_head;
}
} }
} }
} }
...@@ -1298,7 +1305,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -1298,7 +1305,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank, int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode) { bool is_cached_dispatch, bool low_latency_mode) {
const int num_threads = ::max(128, kWarpSize * num_channels); const int num_threads = ::min(1024, ::max(128, kWarpSize * num_channels));
const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
// Get clean meta // Get clean meta
...@@ -1314,11 +1321,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -1314,11 +1321,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
num_nvl_bytes); num_nvl_bytes);
EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max()); EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max()); EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_channels * 2 > 2); EP_HOST_ASSERT(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL > 2);
// Launch kernel // Launch kernel
auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>; auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream); SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, num_threads, stream);
LAUNCH_KERNEL_NON_COOPERATIVE( LAUNCH_KERNEL_NON_COOPERATIVE(
&cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second, &cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second,
nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens, nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens,
...@@ -1327,11 +1334,12 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -1327,11 +1334,12 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
cpu_rdma_team); cpu_rdma_team);
} }
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn> template <int kNumRanks, typename dtype_t, int kMaxNumRanks, bool kUseMLS, typename GetAddrFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx, __device__ int combine_token(bool is_token_in_rank, int head_idx,
int lane_id, int hidden_int4, int num_topk, int lane_id, int hidden_int4, int num_topk,
int4* combined_row, float* combined_topk_weights, int4* combined_row, float* combined_topk_weights,
int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) { int num_max_recv_tokens,
const GetAddrFn& get_addr_fn, const ReceiveTWFn& recv_tw_fn) {
constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
// Broadcast current heads // Broadcast current heads
...@@ -1353,7 +1361,7 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx, ...@@ -1353,7 +1361,7 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
int4 recv_value_int4[kMaxNumRanks]; int4 recv_value_int4[kMaxNumRanks];
#pragma unroll #pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j) for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i); recv_value_int4[j] = ld_nc_global(get_addr_fn(topk_ranks[j], slot_indices[j], i));
// Reduce all-to-all results // Reduce all-to-all results
float values[kDtypePerInt4] = {0}; float values[kDtypePerInt4] = {0};
...@@ -1416,6 +1424,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1416,6 +1424,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
__shared__ shmem_ctx_t ctx; __shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx); shmem_wg_ctx_create(&ctx);
#endif #endif
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize; const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
...@@ -1717,14 +1726,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1717,14 +1726,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Combine current token // Combine current token
auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token; void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); }; auto get_addr_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast<int4*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4) + hidden_int4_idx; };
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); }; auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0, combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS, true>(expected_head >= 0,
expected_head, lane_id, expected_head, lane_id,
hidden_int4, num_topk, hidden_int4, num_topk,
reinterpret_cast<int4*>(shifted), reinterpret_cast<int4*>(shifted),
reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)), reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn); num_max_nvl_chunked_recv_tokens_per_rdma,
get_addr_fn, recv_tw_fn);
// Update head // Update head
if(lane_id < NUM_MAX_NVL_PEERS) { if(lane_id < NUM_MAX_NVL_PEERS) {
...@@ -1787,7 +1797,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1787,7 +1797,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int last_nvl_head[kNumRDMARanks] = {0}; int last_nvl_head[kNumRDMARanks] = {0};
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0; int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) { while(true) {
// Retired // Retired
...@@ -1853,14 +1862,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1853,14 +1862,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
syncwarp(); syncwarp();
// Combine current token // Combine current token
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);}; auto get_addr_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast<int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx; };
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);}; auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0, combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks, false>(expected_head >= 0,
expected_head, lane_id, expected_head, lane_id,
hidden_int4, num_topk, hidden_int4, num_topk,
combined_x + token_idx * hidden_int4, combined_x + token_idx * hidden_int4,
combined_topk_weights + token_idx * num_topk, combined_topk_weights + token_idx * num_topk,
num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn); num_max_rdma_chunked_recv_tokens,
get_addr_fn, recv_tw_fn);
} }
// Retired // Retired
...@@ -1879,7 +1889,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_ ...@@ -1879,7 +1889,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int last_nvl_head[kNumRDMARanks] = {0}; int last_nvl_head[kNumRDMARanks] = {0};
int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0; int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0;
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0; int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) { while(true) {
// Retired // Retired
......
...@@ -207,7 +207,8 @@ class Buffer: ...@@ -207,7 +207,8 @@ class Buffer:
new_num_sms: the new number to be set. new_num_sms: the new number to be set.
""" """
assert new_num_sms % 2 == 0, "The SM count must be even" assert new_num_sms % 2 == 0, "The SM count must be new_num_sms % 2 == 0"
assert new_num_sms % 3 == 0, "The SM count must be new_num_sms % 3 == 0"
Buffer.num_sms = new_num_sms Buffer.num_sms = new_num_sms
@staticmethod @staticmethod
......
#!/bin/bash
# rocSHMEM # rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288 export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48 export ROCSHMEM_MAX_NUM_CONTEXTS=60
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9 export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=3737418240 export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# # duSHMEM # # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH # export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
......
#!/bin/bash
# rocSHMEM # rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288 export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48 export ROCSHMEM_MAX_NUM_CONTEXTS=60
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9 export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=3737418240 export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# # duSHMEM # # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH # export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
......
...@@ -143,7 +143,8 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -143,7 +143,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
# Check `topk_weights` # Check `topk_weights`
if not is_rand: if not is_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] max_weights = recv_topk_weights.amax(dim=1, keepdim=True) # Shape: [Batch, 1]
recv_topk_weights = torch.where(recv_topk_idx == -1, max_weights, recv_topk_weights)
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum) check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
# Test cached dispatch (must without top-k staffs) # Test cached dispatch (must without top-k staffs)
...@@ -186,6 +187,7 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -186,6 +187,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
if local_rank == 0: if local_rank == 0:
print(' passed', flush=True) print(' passed', flush=True)
if local_rank == 0: if local_rank == 0:
print('', flush=True) print('', flush=True)
...@@ -201,6 +203,8 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -201,6 +203,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in range(4, 45, 4): for nvl_chunk_size in range(4, 45, 4):
for rdma_chunk_size in range(4, 33, 4): for rdma_chunk_size in range(4, 33, 4):
if rdma_buffer_size % rdma_chunk_size != 0:
continue
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': current_x, 'handle': handle, 'config': config} tune_args = {'x': current_x, 'handle': handle, 'config': config}
t, notify_t = bench_kineto(lambda: buffer.dispatch(**tune_args), ('dispatch', 'notify'), suppress_kineto_output=True) t, notify_t = bench_kineto(lambda: buffer.dispatch(**tune_args), ('dispatch', 'notify'), suppress_kineto_output=True)
...@@ -233,6 +237,8 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -233,6 +237,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
best_time, best_results = 1e10, None best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 8, 1): for nvl_chunk_size in range(1, 8, 1):
for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4): for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4):
if rdma_buffer_size % rdma_chunk_size != 0:
continue
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config} tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t, notify_t = bench_kineto(lambda: buffer.combine(**tune_args), ('combine', 'notify'), suppress_kineto_output=True) t, notify_t = bench_kineto(lambda: buffer.combine(**tune_args), ('combine', 'notify'), suppress_kineto_output=True)
...@@ -265,8 +271,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -265,8 +271,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_rdma_bytes_ll = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts) num_rdma_bytes_ll = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
num_sms = 48 num_sms = 60
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0) num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
deep_ep.Buffer.set_num_sms(num_sms)
hidden_bytes = get_hidden_bytes(args) hidden_bytes = get_hidden_bytes(args)
num_nvl_bytes, num_rdma_bytes, num_rdma_bytes_norm = 0, 0, 0 num_nvl_bytes, num_rdma_bytes, num_rdma_bytes_norm = 0, 0, 0
...@@ -292,7 +299,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -292,7 +299,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
break break
if local_rank == 0: if local_rank == 0:
print(f'{ref_hash=}') print(f'ref_hash={ref_hash}')
print('', flush=True) print('', flush=True)
for j in range(20): for j in range(20):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment