Commit ab0afb04 authored by lishen's avatar lishen
Browse files

Merge branch 'normal_update' into 'main'

Normal update

See merge request dcutoolkit/deeplearing/DeepEP!28
parents 766b17b3 30aa7a87
......@@ -47,21 +47,25 @@ struct Config {
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % (2 * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL) == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms;
const int num_channels = num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
size_t num_bytes = 0;
num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
// 计算每个nvl通信数据包的数据量
size_t num_single_nvl_bag_bytes =
hidden_bytes + // 数据缓冲区(Token Data)。存储从 RDMA 转发过来的 token 数据(x 张量)
#ifndef DISABLE_ROCSHMEM
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens *
internode::get_source_meta_bytes();
internode::get_source_meta_bytes() + // 源元数据缓冲区(Source Metadata)。存储每个 token 的源信息(哪个 RDMA rank 发送的)
#endif
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK *
sizeof(int64_t);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK *
sizeof(float);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens *
kNumMaxScales * sizeof(float);
kNumMaxTopK * sizeof(int) + // TopK 索引缓冲区。存储每个 token 的 top-k 专家索引
kNumMaxTopK * sizeof(float) + // TopK 权重缓冲区。存储每个 token 的 top-k 专家权重
kNumMaxScales * sizeof(float); // Scale 缓冲区。存储每个 token 的量化缩放因子
// 计算每个 NVL channel 的控制信息所需的字节数,存储每个 NVL channel 的前缀索引信息,用于快速定位数据(nvl_channel_prefix_start、nvl_channel_prefix_end 等)
size_t num_single_nvl_control_bytes = (2 * num_rdma_ranks + 3) * sizeof(int);
// NVL 数据总的字节数
size_t num_bytes = (num_single_nvl_bag_bytes * num_max_nvl_chunked_recv_tokens + num_single_nvl_control_bytes) * num_channels * num_nvl_ranks;
// 128 字节对齐,匹配 GPU 缓存行大小,优化内存访问。
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
}
......@@ -79,22 +83,25 @@ struct Config {
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_channels = num_sms;
size_t num_bytes = 0;
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
num_bytes +=
num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxTopK * sizeof(int64_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxScales * sizeof(float) * 2;
num_bytes +=
num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
const int num_channels = num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
// 计算每个rdma通信数据包的数据量
size_t num_single_rdma_bag_bytes =
hidden_bytes + // 数据缓冲区。存储实际的 token 数据(x 张量),对应代码中的 rdma_channel_data
internode::get_source_meta_bytes() + // 源元数据缓冲区。存储每个 token 的源信息(SourceMeta)
kNumMaxTopK * sizeof(int) + // 存储每个 token 的 top-k 专家索引。对应 topk_idx 数据
kNumMaxTopK * sizeof(float) + // 存储每个 token 的 top-k 专家权重。对应 topk_weights 数据
kNumMaxScales * sizeof(float) + // 存储每个 token 的缩放因子(x_scales)
sizeof(int4); // 预留空间用于内存对齐和未来扩展
// 计算每个 RDMA channel 的控制信息(起始/结束索引)所需的字节数,对应代码中的 rdma_channel_meta
size_t num_single_rdma_control_bytes = (NUM_MAX_NVL_PEERS * 2 + 4) * sizeof(int);
// RDMA 数据总的字节数
size_t num_bytes = (num_single_rdma_bag_bytes * num_max_rdma_chunked_recv_tokens + num_single_rdma_control_bytes) *
num_channels * num_rdma_ranks * 2;
// 128 字节对齐(缓存行对齐),优化内存访问性能
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
#else
......
......@@ -937,6 +937,7 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value();
recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value();
EP_HOST_ASSERT(num_rdma_bytes >= config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks));
// Just a barrier and clean flags
internode::cached_notify(
hidden_int4, num_scales, num_topk, num_topk, num_ranks, num_channels, 0, nullptr,
......@@ -1205,6 +1206,7 @@ Buffer::internode_combine(
EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <=
config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks);
EP_HOST_ASSERT(num_rdma_bytes >= config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks));
// Launch barrier and reset queue head and tail
internode::cached_notify(
......
......@@ -7,6 +7,10 @@
#ifndef DISABLE_ROCSHMEM
// 安全检查:确保宏已定义
#ifndef HIP_VERSION_PATCH
#error "HIP_VERSION_PATCH not defined! Check your HIP installation."
#endif
// TODO: fix unroll warnings
// #ifdef __clang__
// #pragma clang diagnostic push
......@@ -56,16 +60,18 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
__host__ __device__ __forceinline__ std::pair<int, int>
get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_channels) {
// Return `int32_t` offset and count to clean
return {(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) *
num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / sizeof(int),
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms};
num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_channels) / sizeof(int),
(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels};
}
__host__ __device__ __forceinline__ std::pair<int, int>
get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens,
int num_sms) {
int num_channels) {
// Return `int32_t` offset and to clean
EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0,
"Invalid size of `SourceMeta`");
......@@ -73,8 +79,8 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
(num_nvl_recv_buffer_tokens *
(hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
num_topk_weights * sizeof(float) + sizeof(SourceMeta)) *
num_nvl_ranks * num_sms) / sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms,
num_nvl_ranks * num_channels) / sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels,
};
}
......@@ -1230,13 +1236,13 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if (is_cached_dispatch)
return;
EP_DEVICE_ASSERT(num_warps >= num_channels);
EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize);
// Iterate in reverse order
if (lane_id < num_rdma_ranks and warp_id < num_channels) {
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
if (lane_id < num_rdma_ranks) {
int token_start_idx, token_end_idx;
get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx,
get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx,
token_end_idx);
// NOTES: `1 << 25` is a heuristic large number
......@@ -1251,26 +1257,26 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
}
}
}
}
} else {
if (is_cached_dispatch)
return;
EP_DEVICE_ASSERT(num_warps >= num_channels);
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and
rdma_rank_prefix_sum != nullptr);
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers");
constexpr int num_clean_sms = 2;
if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) {
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
if (lane_id < NUM_MAX_NVL_PEERS ) {
for (int dst_rdma_rank = sm_id - num_clean_sms; dst_rdma_rank < num_rdma_ranks;
dst_rdma_rank += num_channels * 2 - num_clean_sms) {
dst_rdma_rank += num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL - num_clean_sms) {
// Iterate in reverse order
int token_start_idx =
warp_id == 0
channel_id == 0
? 0
: rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
: rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
int token_end_idx =
rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
token_start_idx += shift, token_end_idx += shift;
......@@ -1288,6 +1294,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
}
}
}
}
}
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
......@@ -1298,7 +1305,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode) {
const int num_threads = ::max(128, kWarpSize * num_channels);
const int num_threads = ::min(1024, ::max(128, kWarpSize * num_channels));
const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
// Get clean meta
......@@ -1314,11 +1321,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
num_nvl_bytes);
EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_channels * 2 > 2);
EP_HOST_ASSERT(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL > 2);
// Launch kernel
auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, num_threads, stream);
LAUNCH_KERNEL_NON_COOPERATIVE(
&cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second,
nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens,
......@@ -1327,11 +1334,12 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
cpu_rdma_team);
}
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, bool kUseMLS, typename GetAddrFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
int lane_id, int hidden_int4, int num_topk,
int4* combined_row, float* combined_topk_weights,
int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
int num_max_recv_tokens,
const GetAddrFn& get_addr_fn, const ReceiveTWFn& recv_tw_fn) {
constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
// Broadcast current heads
......@@ -1353,7 +1361,7 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
int4 recv_value_int4[kMaxNumRanks];
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);
recv_value_int4[j] = ld_nc_global(get_addr_fn(topk_ranks[j], slot_indices[j], i));
// Reduce all-to-all results
float values[kDtypePerInt4] = {0};
......@@ -1416,6 +1424,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
__shared__ shmem_ctx_t ctx;
shmem_wg_ctx_create(&ctx);
#endif
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
......@@ -1717,14 +1726,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Combine current token
auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
auto get_addr_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast<int4*>(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4) + hidden_int4_idx; };
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS, true>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
reinterpret_cast<int4*>(shifted),
reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
num_max_nvl_chunked_recv_tokens_per_rdma,
get_addr_fn, recv_tw_fn);
// Update head
if(lane_id < NUM_MAX_NVL_PEERS) {
......@@ -1787,7 +1797,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int last_nvl_head[kNumRDMARanks] = {0};
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) {
// Retired
......@@ -1853,14 +1862,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
syncwarp();
// Combine current token
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
auto get_addr_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast<int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx; };
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks, false>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
combined_x + token_idx * hidden_int4,
combined_topk_weights + token_idx * num_topk,
num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
num_max_rdma_chunked_recv_tokens,
get_addr_fn, recv_tw_fn);
}
// Retired
......@@ -1879,7 +1889,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int last_nvl_head[kNumRDMARanks] = {0};
int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0;
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) {
// Retired
......
......@@ -207,7 +207,8 @@ class Buffer:
new_num_sms: the new number to be set.
"""
assert new_num_sms % 2 == 0, "The SM count must be even"
assert new_num_sms % 2 == 0, "The SM count must be new_num_sms % 2 == 0"
assert new_num_sms % 3 == 0, "The SM count must be new_num_sms % 3 == 0"
Buffer.num_sms = new_num_sms
@staticmethod
......
......@@ -2,10 +2,13 @@
# rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_MAX_NUM_CONTEXTS=60
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# NMZ使用
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
......@@ -17,8 +20,8 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PYTHONPATH=$(pwd)/../
# test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
......@@ -2,10 +2,13 @@
# rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_MAX_NUM_CONTEXTS=60
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# NMZ使用
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
......@@ -17,8 +20,8 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PYTHONPATH=$(pwd)/../
# test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
......@@ -143,7 +143,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
# Check `topk_weights`
if not is_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
max_weights = recv_topk_weights.amax(dim=1, keepdim=True) # Shape: [Batch, 1]
recv_topk_weights = torch.where(recv_topk_idx == -1, max_weights, recv_topk_weights)
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
# Test cached dispatch (must without top-k staffs)
......@@ -186,6 +187,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
if local_rank == 0:
print(' passed', flush=True)
if local_rank == 0:
print('', flush=True)
......@@ -201,6 +203,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in range(4, 45, 4):
for rdma_chunk_size in range(4, 33, 4):
if rdma_buffer_size % rdma_chunk_size != 0:
continue
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': current_x, 'handle': handle, 'config': config}
t, notify_t = bench_kineto(lambda: buffer.dispatch(**tune_args), ('dispatch', 'notify'), suppress_kineto_output=True)
......@@ -233,6 +237,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 8, 1):
for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4):
if rdma_buffer_size % rdma_chunk_size != 0:
continue
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t, notify_t = bench_kineto(lambda: buffer.combine(**tune_args), ('combine', 'notify'), suppress_kineto_output=True)
......@@ -265,8 +271,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_rdma_bytes_ll = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
num_sms = 48
num_sms = 60
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
deep_ep.Buffer.set_num_sms(num_sms)
hidden_bytes = get_hidden_bytes(args)
num_nvl_bytes, num_rdma_bytes, num_rdma_bytes_norm = 0, 0, 0
......@@ -292,7 +299,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
break
if local_rank == 0:
print(f'{ref_hash=}')
print(f'ref_hash={ref_hash}')
print('', flush=True)
for j in range(20):
......
......@@ -244,7 +244,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), explicitly_destroy=True)
torch.manual_seed(rank)
for i in (24, ):
for i in (60, ):
test_main(args, i, local_rank, num_ranks, rank, buffer, group)
if local_rank == 0:
print('', flush=True)
......
......@@ -52,8 +52,9 @@ def test_main(num_tokens: int,
seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
if rank == 0:
print(f"enable_dispatch_ll_layered={enable_dispatch_ll_layered}, enable_combine_overlap={enable_combine_overlap}, use_logfmt={use_logfmt}")
assert not (use_logfmt and (enable_dispatch_ll_layered or enable_combine_overlap)), \
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
assert num_experts % num_ranks == 0
......@@ -144,7 +145,7 @@ def test_main(num_tokens: int,
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
if (enable_dispatch_ll_layered or enable_combine_overlap):
if enable_dispatch_ll_layered or enable_combine_overlap:
recv_src_info = recv_src_info[:num_valid_tokens] & int_mask # 掩掉多余的信息
else:
recv_src_info = recv_src_info[:num_valid_tokens]
......@@ -179,7 +180,7 @@ def test_main(num_tokens: int,
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
if enable_combine_overlap:
block_m, threshold, num_sms = 64, 10, 3
total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m) # 每个本地专家 总的信号数??
total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m) # 每个本地专家 总的信号数
comp_signal = torch.zeros(num_local_experts * total_num_per_expert, dtype=torch.int32, device='cuda')
for i in range(num_local_experts):
......
......@@ -8,12 +8,15 @@ export PYTHONPATH=$(pwd)
# rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_MAX_NUM_CONTEXTS=60
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240
export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=$(pwd)/tests_mpi/topo.config
# NMZ使用
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export NVSHMEM_SYMMETRIC_SIZE=10737418240
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
# export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
# export NVSHMEM_SYMMETRIC_SIZE=10737418240
......@@ -145,7 +145,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
# Check `topk_weights`
if not is_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
max_weights = recv_topk_weights.amax(dim=1, keepdim=True) # Shape: [Batch, 1]
recv_topk_weights = torch.where(recv_topk_idx == -1, max_weights, recv_topk_weights)
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
# Test cached dispatch (must without top-k staffs)
......@@ -203,6 +204,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in range(4, 45, 4):
for rdma_chunk_size in range(4, 33, 4):
if rdma_buffer_size % rdma_chunk_size != 0:
continue
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': current_x, 'handle': handle, 'config': config}
t, notify_t = bench_kineto(lambda: buffer.dispatch(**tune_args), ('dispatch', 'notify'), suppress_kineto_output=True)
......@@ -235,6 +238,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 8, 1):
for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4):
if rdma_buffer_size % rdma_chunk_size != 0:
continue
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t, notify_t = bench_kineto(lambda: buffer.combine(**tune_args), ('combine', 'notify'), suppress_kineto_output=True)
......@@ -272,8 +277,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_rdma_bytes_ll = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
num_sms = 48
num_sms = 60
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
deep_ep.Buffer.set_num_sms(num_sms)
hidden_bytes = get_hidden_bytes(args)
num_nvl_bytes, num_rdma_bytes, num_rdma_bytes_norm = 0, 0, 0
......@@ -299,7 +305,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
break
if rank == 0:
print(f'{ref_hash=}')
print(f'ref_hash={ref_hash}')
print('', flush=True)
for j in range(20):
......
......@@ -119,7 +119,8 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
# Check `topk_weights`
recv_topk_weights_clone = recv_topk_weights.clone()
if current_x is not x_pure_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
max_weights = recv_topk_weights.amax(dim=1, keepdim=True) # Shape: [Batch, 1]
recv_topk_weights = torch.where(recv_topk_idx == -1, max_weights, recv_topk_weights)
check_data(recv_topk_weights, rank_prefix_matrix)
# Test `num_worst_tokens != 0`
......@@ -251,7 +252,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), explicitly_destroy=True)
torch.manual_seed(rank)
for i in (48, ):
for i in (60, ):
test_main(args, i, local_rank, num_ranks, rank, buffer, group)
if local_rank == 0:
print('', flush=True)
......
......@@ -36,6 +36,10 @@ def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], bu
assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks
def ceil_div(a, b):
return (a + b - 1) // b
def test_main(num_tokens: int,
hidden: int,
num_experts: int,
......@@ -44,11 +48,17 @@ def test_main(num_tokens: int,
num_ranks: int,
group: dist.ProcessGroup,
buffer: deep_ep.Buffer,
enable_dispatch_ll_layered: bool = False,
enable_combine_overlap: bool = False,
use_logfmt: bool = False,
seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
if rank == 0:
print(f"enable_dispatch_ll_layered={enable_dispatch_ll_layered}, enable_combine_overlap={enable_combine_overlap}, use_logfmt={use_logfmt}")
assert not (use_logfmt and (enable_dispatch_ll_layered or enable_combine_overlap)), \
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
......@@ -86,6 +96,9 @@ def test_main(num_tokens: int,
hash_value, num_times = 0, 0
for x_i, current_x in enumerate(x_list):
for return_recv_hook in (False, True):
if enable_combine_overlap and (not return_recv_hook): # return_recv_hook 为False 时,不能启用 overlop
continue
for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0
for fp8_round_scale in (False, True) if quant_type != 3 else (True, ):
......@@ -133,7 +146,12 @@ def test_main(num_tokens: int,
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
if enable_dispatch_ll_layered or enable_combine_overlap:
recv_src_info = recv_src_info[:num_valid_tokens] & int_mask # 掩掉多余的信息
else:
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x_amax)
if dispatch_use_quant:
......@@ -150,6 +168,7 @@ def test_main(num_tokens: int,
if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_quant:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
......@@ -161,6 +180,28 @@ def test_main(num_tokens: int,
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
if enable_combine_overlap:
block_m, threshold, num_sms = 64, 10, 3
total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m) # 每个本地专家 总的信号数
comp_signal = torch.zeros(num_local_experts * total_num_per_expert, dtype=torch.int32, device='cuda')
for i in range(num_local_experts):
vaild_num = ceil_div(packed_recv_count[i], block_m)
comp_signal[i * total_num_per_expert:i * total_num_per_expert + vaild_num] = threshold
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
packed_recv_count=packed_recv_count,
comp_signal=comp_signal,
block_m=block_m,
threshold=threshold,
num_sms=num_sms,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
else:
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
......@@ -170,6 +211,7 @@ def test_main(num_tokens: int,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
hook() if return_recv_hook else event.current_stream_wait()
if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
......@@ -181,8 +223,10 @@ def test_main(num_tokens: int,
if rank == 0:
print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")
if rank == 0:
print('', flush=True)
print("deep_ep 全部正确性测试完成")
if enable_dispatch_ll_layered or enable_combine_overlap:
return hash_value
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
......@@ -252,9 +296,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts
print(f"num_tokens, hidden, num_ranks, num_experts = {num_tokens}, {hidden}, {num_ranks}, {num_experts}")
enable_dispatch_ll_layered = args.enable_dispatch_ll_layered
enable_combine_overlap = args.enable_combine_overlap
if enable_dispatch_ll_layered:
enable_combine_overlap = True
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered=enable_dispatch_ll_layered)
if rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group,
......@@ -263,7 +311,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
explicitly_destroy=True,
allow_mnnvl=args.allow_mnnvl)
allow_mnnvl=args.allow_mnnvl,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap
)
print("deep_ep 初始化完成")
test_main(num_tokens,
hidden,
num_experts,
......@@ -273,6 +325,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=1)
do_pressure_test = args.pressure_test
......@@ -288,6 +342,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed)
for _ in range(20):
assert test_main(num_tokens,
......@@ -299,6 +355,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group
......@@ -331,6 +389,10 @@ if __name__ == '__main__':
parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
# 新版 sbo 需要的
parser.add_argument('--enable-dispatch-ll-layered', action='store_true', help='Enable low-latency layered dispatch optimization')
parser.add_argument("--enable-combine-overlap", action='store_true', help='Enable GEMM-compute/communication overlap in the combine phase')
args = parser.parse_args()
if args.world_size > args.num_processes:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment