Commit 243eca85 authored by lishen01's avatar lishen01
Browse files

fix: 解决高吞吐的SM最大只能到48的问题,提升高吞吐的整体性能

parent 766b17b3
......@@ -47,21 +47,25 @@ struct Config {
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % (2 * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL) == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms;
const int num_channels = num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
size_t num_bytes = 0;
num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
// 计算每个nvl通信数据包的数据量
size_t num_single_nvl_bag_bytes =
hidden_bytes + // 数据缓冲区(Token Data)。存储从 RDMA 转发过来的 token 数据(x 张量)
#ifndef DISABLE_ROCSHMEM
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens *
internode::get_source_meta_bytes();
internode::get_source_meta_bytes() + // 源元数据缓冲区(Source Metadata)。存储每个 token 的源信息(哪个 RDMA rank 发送的)
#endif
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK *
sizeof(int64_t);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK *
sizeof(float);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens *
kNumMaxScales * sizeof(float);
kNumMaxTopK * sizeof(int) + // TopK 索引缓冲区。存储每个 token 的 top-k 专家索引
kNumMaxTopK * sizeof(float) + // TopK 权重缓冲区。存储每个 token 的 top-k 专家权重
kNumMaxScales * sizeof(float); // Scale 缓冲区。存储每个 token 的量化缩放因子
// 计算每个 NVL channel 的控制信息所需的字节数,存储每个 NVL channel 的前缀索引信息,用于快速定位数据(nvl_channel_prefix_start、nvl_channel_prefix_end 等)
size_t num_single_nvl_control_bytes = (2 * num_rdma_ranks + 3) * sizeof(int);
// NVL 数据总的字节数
size_t num_bytes = (num_single_nvl_bag_bytes * num_max_nvl_chunked_recv_tokens + num_single_nvl_control_bytes) * num_channels * num_nvl_ranks;
// 128 字节对齐,匹配 GPU 缓存行大小,优化内存访问。
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
}
......@@ -79,22 +83,25 @@ struct Config {
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_channels = num_sms;
size_t num_bytes = 0;
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
num_bytes +=
num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxTopK * sizeof(int64_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxScales * sizeof(float) * 2;
num_bytes +=
num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
const int num_channels = num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
// 计算每个rdma通信数据包的数据量
size_t num_single_rdma_bag_bytes =
hidden_bytes + // 数据缓冲区。存储实际的 token 数据(x 张量),对应代码中的 rdma_channel_data
internode::get_source_meta_bytes() + // 源元数据缓冲区。存储每个 token 的源信息(SourceMeta)
kNumMaxTopK * sizeof(int) + // 存储每个 token 的 top-k 专家索引。对应 topk_idx 数据
kNumMaxTopK * sizeof(float) + // 存储每个 token 的 top-k 专家权重。对应 topk_weights 数据
kNumMaxScales * sizeof(float) + // 存储每个 token 的缩放因子(x_scales)
sizeof(int4); // 预留空间用于内存对齐和未来扩展
// 计算每个 RDMA channel 的控制信息(起始/结束索引)所需的字节数,对应代码中的 rdma_channel_meta
size_t num_single_rdma_control_bytes = (NUM_MAX_NVL_PEERS * 2 + 4) * sizeof(int);
// RDMA 数据总的字节数
size_t num_bytes = (num_single_rdma_bag_bytes * num_max_rdma_chunked_recv_tokens + num_single_rdma_control_bytes) *
num_channels * num_rdma_ranks * 2;
// 128 字节对齐(缓存行对齐),优化内存访问性能
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
#else
......
......@@ -937,6 +937,7 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value();
recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value();
EP_HOST_ASSERT(num_rdma_bytes >= config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks));
// Just a barrier and clean flags
internode::cached_notify(
hidden_int4, num_scales, num_topk, num_topk, num_ranks, num_channels, 0, nullptr,
......@@ -1205,6 +1206,7 @@ Buffer::internode_combine(
EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <=
config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks);
EP_HOST_ASSERT(num_rdma_bytes >= config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks));
// Launch barrier and reset queue head and tail
internode::cached_notify(
......
......@@ -7,6 +7,10 @@
#ifndef DISABLE_ROCSHMEM
// 安全检查:确保宏已定义
#ifndef HIP_VERSION_PATCH
#error "HIP_VERSION_PATCH not defined! Check your HIP installation."
#endif
// TODO: fix unroll warnings
// #ifdef __clang__
// #pragma clang diagnostic push
......@@ -56,16 +60,18 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
__host__ __device__ __forceinline__ std::pair<int, int>
get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_channels) {
// Return `int32_t` offset and count to clean
return {(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) *
num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / sizeof(int),
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms};
num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_channels) / sizeof(int),
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels};
}
__host__ __device__ __forceinline__ std::pair<int, int>
get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens,
int num_sms) {
int num_channels) {
// Return `int32_t` offset and to clean
EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0,
"Invalid size of `SourceMeta`");
......@@ -73,8 +79,8 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
(num_nvl_recv_buffer_tokens *
(hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
num_topk_weights * sizeof(float) + sizeof(SourceMeta)) *
num_nvl_ranks * num_sms) / sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms,
num_nvl_ranks * num_channels) / sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels,
};
}
......@@ -1230,13 +1236,13 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if (is_cached_dispatch)
return;
EP_DEVICE_ASSERT(num_warps >= num_channels);
EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize);
// Iterate in reverse order
if (lane_id < num_rdma_ranks and warp_id < num_channels) {
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
if (lane_id < num_rdma_ranks) {
int token_start_idx, token_end_idx;
get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx,
get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx,
token_end_idx);
// NOTES: `1 << 25` is a heuristic large number
......@@ -1251,26 +1257,26 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
}
}
}
}
} else {
if (is_cached_dispatch)
return;
EP_DEVICE_ASSERT(num_warps >= num_channels);
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and
rdma_rank_prefix_sum != nullptr);
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers");
constexpr int num_clean_sms = 2;
if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) {
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
if (lane_id < NUM_MAX_NVL_PEERS ) {
for (int dst_rdma_rank = sm_id - num_clean_sms; dst_rdma_rank < num_rdma_ranks;
dst_rdma_rank += num_channels * 2 - num_clean_sms) {
dst_rdma_rank += num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL - num_clean_sms) {
// Iterate in reverse order
int token_start_idx =
warp_id == 0
channel_id == 0
? 0
: rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
: rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
int token_end_idx =
rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
token_start_idx += shift, token_end_idx += shift;
......@@ -1288,6 +1294,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
}
}
}
}
}
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
......@@ -1298,7 +1305,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode) {
const int num_threads = ::max(128, kWarpSize * num_channels);
const int num_threads = ::min(1024, ::max(128, kWarpSize * num_channels));
const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
// Get clean meta
......@@ -1314,11 +1321,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
num_nvl_bytes);
EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_channels * 2 > 2);
EP_HOST_ASSERT(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL > 2);
// Launch kernel
auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, num_threads, stream);
LAUNCH_KERNEL_NON_COOPERATIVE(
&cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second,
nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens,
......@@ -1327,11 +1334,12 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
cpu_rdma_team);
}
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, bool kUseMLS, typename GetAddrFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
int lane_id, int hidden_int4, int num_topk,
int4* combined_row, float* combined_topk_weights,
int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
int num_max_recv_tokens,
const GetAddrFn& get_addr_fn, const ReceiveTWFn& recv_tw_fn) {
constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
// Broadcast current heads
......@@ -1353,7 +1361,7 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
int4 recv_value_int4[kMaxNumRanks];
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);
recv_value_int4[j] = ld_nc_global(get_addr_fn(topk_ranks[j], slot_indices[j], i));
// Reduce all-to-all results
float values[kDtypePerInt4] = {0};
......@@ -1416,6 +1424,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
__shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx);
#endif
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
......@@ -1717,14 +1726,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Combine current token
auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
auto get_addr_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast<int4*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4) + hidden_int4_idx; };
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS, true>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
reinterpret_cast<int4*>(shifted),
reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
num_max_nvl_chunked_recv_tokens_per_rdma,
get_addr_fn, recv_tw_fn);
// Update head
if(lane_id < NUM_MAX_NVL_PEERS) {
......@@ -1787,7 +1797,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int last_nvl_head[kNumRDMARanks] = {0};
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) {
// Retired
......@@ -1853,14 +1862,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
syncwarp();
// Combine current token
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
auto get_addr_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast<int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx; };
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks, false>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
combined_x + token_idx * hidden_int4,
combined_topk_weights + token_idx * num_topk,
num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
num_max_rdma_chunked_recv_tokens,
get_addr_fn, recv_tw_fn);
}
// Retired
......@@ -1879,7 +1889,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int last_nvl_head[kNumRDMARanks] = {0};
int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0;
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) {
// Retired
......
......@@ -207,7 +207,8 @@ class Buffer:
new_num_sms: the new number to be set.
"""
assert new_num_sms % 2 == 0, "The SM count must be even"
assert new_num_sms % 2 == 0, "The SM count must be new_num_sms % 2 == 0"
assert new_num_sms % 3 == 0, "The SM count must be new_num_sms % 3 == 0"
Buffer.num_sms = new_num_sms
@staticmethod
......
#!/bin/bash
# rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_MAX_NUM_CONTEXTS=60
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
......
#!/bin/bash
# rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_MAX_NUM_CONTEXTS=60
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
......
......@@ -143,7 +143,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
# Check `topk_weights`
if not is_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
max_weights = recv_topk_weights.amax(dim=1, keepdim=True) # Shape: [Batch, 1]
recv_topk_weights = torch.where(recv_topk_idx == -1, max_weights, recv_topk_weights)
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
# Test cached dispatch (must without top-k staffs)
......@@ -186,6 +187,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
if local_rank == 0:
print(' passed', flush=True)
if local_rank == 0:
print('', flush=True)
......@@ -201,6 +203,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in range(4, 45, 4):
for rdma_chunk_size in range(4, 33, 4):
if rdma_buffer_size % rdma_chunk_size != 0:
continue
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': current_x, 'handle': handle, 'config': config}
t, notify_t = bench_kineto(lambda: buffer.dispatch(**tune_args), ('dispatch', 'notify'), suppress_kineto_output=True)
......@@ -233,6 +237,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 8, 1):
for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4):
if rdma_buffer_size % rdma_chunk_size != 0:
continue
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t, notify_t = bench_kineto(lambda: buffer.combine(**tune_args), ('combine', 'notify'), suppress_kineto_output=True)
......@@ -265,8 +271,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_rdma_bytes_ll = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
num_sms = 48
num_sms = 60
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
deep_ep.Buffer.set_num_sms(num_sms)
hidden_bytes = get_hidden_bytes(args)
num_nvl_bytes, num_rdma_bytes, num_rdma_bytes_norm = 0, 0, 0
......@@ -292,7 +299,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
break
if local_rank == 0:
print(f'{ref_hash=}')
print(f'ref_hash={ref_hash}')
print('', flush=True)
for j in range(20):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment