Commit c2abba11 authored by lishen's avatar lishen
Browse files

Merge branch 'sbo_tmp' into 'main'

low-latency添加dispatch分层优化和combine gemm overlap

See merge request dcutoolkit/deeplearing/DeepEP!26
parents ea76f44e da6da7c3
...@@ -135,9 +135,11 @@ struct LowLatencyLayout { ...@@ -135,9 +135,11 @@ struct LowLatencyLayout {
} }
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts, int quant_group_size=0) { int num_ranks, int num_experts, bool enable_dispatch_ll_layered=false, int quant_group_size=0) {
const int num_scales = hidden / QUANTIZATION_GROUPSIZE; const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
const int num_nodes = num_ranks / NUM_MAX_NVL_PEERS; // 计算结点数
// Dispatch and combine layout: // Dispatch and combine layout:
// - 2 symmetric odd/even send buffer // - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even receive buffers // - 2 symmetric odd/even receive buffers
...@@ -152,7 +154,9 @@ struct LowLatencyLayout { ...@@ -152,7 +154,9 @@ struct LowLatencyLayout {
(quant_group_size == 0 ? 4 : num_scales) * sizeof(float)); // 应该是1,但是代码中为了满足int4对齐 (quant_group_size == 0 ? 4 : num_scales) * sizeof(float)); // 应该是1,但是代码中为了满足int4对齐
// 与internode_ll::combine 中的 num_bytes_per_slot 相等 // 与internode_ll::combine 中的 num_bytes_per_slot 相等
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) + num_scales * sizeof(__hip_bfloat162); size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) +
(enable_dispatch_ll_layered ? 0 : // 即enable_combine_overlap==true,执行函数combine_sbo
num_scales * sizeof(__hip_bfloat162));
// Send buffer // Send buffer
size_t dispatch_send_buffer_bytes = size_t dispatch_send_buffer_bytes =
...@@ -176,6 +180,10 @@ struct LowLatencyLayout { ...@@ -176,6 +180,10 @@ struct LowLatencyLayout {
// Symmetric signaling buffers // Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t); size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t);
if (enable_dispatch_ll_layered) {
dispatch_recv_count_buffer_bytes +=
NUM_MAX_NVL_PEERS * num_nodes * num_max_dispatch_tokens_per_rank * sizeof(int) + NUM_MAX_NVL_PEERS * sizeof(int);
}
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128); size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
...@@ -205,9 +213,11 @@ struct LowLatencyLayout { ...@@ -205,9 +213,11 @@ struct LowLatencyLayout {
}; };
inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts, int quant_group_size=0) { int num_ranks, int num_experts,
bool enable_dispatch_ll_layered=false, int quant_group_size=0) {
auto num_bytes = auto num_bytes =
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size) LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered, quant_group_size)
.total_bytes; .total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
NUM_BUFFER_ALIGNMENT_BYTES; NUM_BUFFER_ALIGNMENT_BYTES;
......
...@@ -13,11 +13,14 @@ ...@@ -13,11 +13,14 @@
namespace deep_ep { namespace deep_ep {
Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool enable_shrink) bool low_latency_mode, bool explicitly_destroy, bool enable_shrink,
bool enable_dispatch_ll_layered, bool enable_combine_overlap)
: rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes), : rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes),
num_rdma_bytes(num_rdma_bytes), low_latency_mode(low_latency_mode), num_rdma_bytes(num_rdma_bytes), low_latency_mode(low_latency_mode),
explicitly_destroy(explicitly_destroy), explicitly_destroy(explicitly_destroy),
enable_shrink(enable_shrink), enable_shrink(enable_shrink),
enable_dispatch_ll_layered(enable_dispatch_ll_layered),
enable_combine_overlap(enable_combine_overlap),
comm_stream(at::hip::getStreamFromPoolMasqueradingAsCUDA(true)) { comm_stream(at::hip::getStreamFromPoolMasqueradingAsCUDA(true)) {
// Metadata memory // Metadata memory
int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
...@@ -25,6 +28,8 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ ...@@ -25,6 +28,8 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int *); int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int *);
EP_HOST_ASSERT(enable_shrink == false); EP_HOST_ASSERT(enable_shrink == false);
if (enable_dispatch_ll_layered)
EP_HOST_ASSERT(enable_combine_overlap == true);
// Common checks // Common checks
EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and
...@@ -1274,7 +1279,8 @@ Buffer::internode_combine( ...@@ -1274,7 +1279,8 @@ Buffer::internode_combine(
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts, int quant_group_size) { void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts, int quant_group_size) {
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size); auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered, quant_group_size);
auto clean_meta_0 = layout.buffers[0].clean_meta(); auto clean_meta_0 = layout.buffers[0].clean_meta();
auto clean_meta_1 = layout.buffers[1].clean_meta(); auto clean_meta_1 = layout.buffers[1].clean_meta();
...@@ -1311,7 +1317,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1311,7 +1317,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto num_local_experts = num_experts / num_ranks; auto num_local_experts = num_experts / num_ranks;
// Buffer control // Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size); LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, enable_dispatch_ll_layered, quant_group_size);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx]; auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
...@@ -1336,7 +1342,16 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1336,7 +1342,16 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
} }
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype)); auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
torch::Dtype dtype = torch::kInt32;
if(enable_dispatch_ll_layered or enable_combine_overlap){
dtype = torch::kInt64;
}
auto packed_recv_src_info = torch::empty(
{num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(dtype).device(torch::kCUDA)
);
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
...@@ -1371,52 +1386,109 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1371,52 +1386,109 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
} }
// Kernel launch if(!enable_dispatch_ll_layered){
auto next_clean_meta = next_buffer.clean_meta(); // Kernel launch
auto launcher = [=](int phases) { auto next_clean_meta = next_buffer.clean_meta();
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, auto launcher = [=](int phases) {
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(), internode_ll::dispatch(
packed_recv_count.data_ptr<int>(), packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
global_atomic_counter.data_ptr<int>(), packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer, packed_recv_count.data_ptr<int>(),
buffer.dispatch_rdma_send_buffer, global_atomic_counter.data_ptr<int>(),
x.data_ptr(), topk_idx.data_ptr<int64_t>(), buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
next_clean_meta.first, next_clean_meta.second, buffer.dispatch_rdma_send_buffer,
num_tokens, hidden, num_max_dispatch_tokens_per_rank, x.data_ptr(), topk_idx.data_ptr<int64_t>(),
num_topk, num_experts, rank, num_ranks, next_clean_meta.first, next_clean_meta.second,
quant_type, quant_group_size, fp8_round_scale, num_tokens, hidden, num_max_dispatch_tokens_per_rank,
workspace, num_device_sms, launch_stream, phases); num_topk, num_experts, rank, num_ranks,
}; quant_type, quant_group_size, fp8_round_scale,
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); workspace, num_device_sms, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Wait streams // Receiver callback
std::optional<EventHandle> event; std::optional<std::function<void()>> recv_hook = std::nullopt;
if (async) { if (return_recv_hook)
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens, recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// so in Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream); // Return values
} else if (not return_recv_hook) { return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
stream_wait(compute_stream, launch_stream); } else {
} // Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
internode_ll::dispatch_ll_layered(
!enable_dispatch_ll_layered,
packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
packed_recv_src_info.data_ptr<int64_t>(), packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
global_atomic_counter.data_ptr<int>(),
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
quant_type, quant_group_size, fp8_round_scale,
workspace, num_device_sms, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Receiver callback // Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt; std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook) if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// Return values
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
}
// Return values
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
} }
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& packed_recv_count,
const std::optional<torch::Tensor>& comp_signal,
int block_m, int threshold, int num_sms,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats, const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) { const std::optional<torch::Tensor>& out) {
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
// combine overlap checks
EP_HOST_ASSERT((!enable_combine_overlap || return_recv_hook) and "Overlap mode requires return_recv_hook=True"); // 启用 overlap 时, 必须 hook = True
EP_HOST_ASSERT((!enable_combine_overlap || packed_recv_count.has_value()) && "Overlap mode requires packed_recv_count has value");
EP_HOST_ASSERT((!enable_combine_overlap || comp_signal.has_value()) && "Overlap mode requires comp_signal has value");
EP_HOST_ASSERT((!enable_combine_overlap || block_m != -1) && "Overlap mode requires block_m != -1");
EP_HOST_ASSERT((!enable_combine_overlap || threshold != -1) && "Overlap mode requires threshold != -1");
EP_HOST_ASSERT((!enable_combine_overlap || num_sms != -1) && "Overlap mode requires num_sms != -1");
if (comp_signal.has_value()) {
EP_HOST_ASSERT(comp_signal->dim() == 1 and comp_signal->is_contiguous());
EP_HOST_ASSERT(comp_signal->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(comp_signal->size(0) == num_experts / num_ranks * ((num_ranks * num_max_dispatch_tokens_per_rank + 63) / 64));
}
// Tensor checks // Tensor checks
EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
...@@ -1430,7 +1502,12 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1430,7 +1502,12 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32); EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous()); EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous());
EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0)); if (!enable_dispatch_ll_layered && !enable_combine_overlap) {
EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0));
} else {
EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt64 and x.size(0) == src_info.size(0));
}
EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous()); EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
...@@ -1446,7 +1523,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1446,7 +1523,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
auto num_combined_tokens = static_cast<int>(topk_weights.size(0)); auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Buffer control // Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, enable_dispatch_ll_layered);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx]; auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
...@@ -1472,44 +1549,91 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1472,44 +1549,91 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
// Kernel launch // Kernel launch
auto next_clean_meta = next_buffer.clean_meta(); auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) { if(!enable_combine_overlap) {
internode_ll::combine(combined_x.data_ptr(), auto launcher = [=](int phases) {
buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, internode_ll::combine(
buffer.combine_rdma_send_buffer, combined_x.data_ptr(),
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(), buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer,
src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(), buffer.combine_rdma_send_buffer,
global_atomic_counter.data_ptr<int>(), x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr, src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second, global_atomic_counter.data_ptr<int>(),
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
num_topk, num_experts, rank, num_ranks, next_clean_meta.first, next_clean_meta.second,
use_logfmt, num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
workspace, num_device_sms, launch_stream, num_topk, num_experts, rank, num_ranks,
phases, zero_copy); use_logfmt,
}; workspace, num_device_sms, launch_stream,
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); phases, zero_copy);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Wait streams // Receiver callback
std::optional<EventHandle> event; std::optional<std::function<void()>> recv_hook = std::nullopt;
if (async) { if (return_recv_hook)
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens, recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// so in Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream); // Return values
} else if (not return_recv_hook) { return {combined_x, event, recv_hook};
stream_wait(compute_stream, launch_stream); } else {
} auto launcher = [=](int phases) {
internode_ll::combine_sbo(
combined_x.data_ptr(),
buffer.combine_rdma_recv_data_buffer,
buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
src_info.data_ptr<int64_t>(), layout_range.data_ptr<int64_t>(),
/* ll_layered 新增参数 */
!enable_dispatch_ll_layered,
/* overlap 新增参数 */
packed_recv_count.has_value() ? packed_recv_count->data_ptr<int>() : nullptr,
comp_signal.has_value() ? comp_signal->data_ptr<int>() : nullptr,
block_m, threshold, num_sms,
/* 辅助tensor */
global_atomic_counter.data_ptr<int>(),
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
workspace, num_device_sms, launch_stream,
phases, zero_copy);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Receiver callback // Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt; std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook) if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// Return values // Return values
return {combined_x, event, recv_hook}; return {combined_x, event, recv_hook};
}
} }
torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, enable_dispatch_ll_layered);
auto buffer = layout.buffers[low_latency_buffer_idx]; auto buffer = layout.buffers[low_latency_buffer_idx];
auto dtype = torch::kBFloat16; auto dtype = torch::kBFloat16;
auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16)); auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
...@@ -1540,7 +1664,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -1540,7 +1664,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait); .def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait);
pybind11::class_<deep_ep::Buffer>(m, "Buffer") pybind11::class_<deep_ep::Buffer>(m, "Buffer")
.def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool>()) .def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool, bool, bool>())
.def("is_available", &deep_ep::Buffer::is_available) .def("is_available", &deep_ep::Buffer::is_available)
.def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks)
.def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank)
......
...@@ -35,6 +35,8 @@ private: ...@@ -35,6 +35,8 @@ private:
// Shrink mode buffer // Shrink mode buffer
bool enable_shrink = false; bool enable_shrink = false;
bool enable_dispatch_ll_layered = false;
bool enable_combine_overlap = false;
int* mask_buffer_ptr = nullptr; int* mask_buffer_ptr = nullptr;
int* sync_buffer_ptr = nullptr; int* sync_buffer_ptr = nullptr;
...@@ -77,7 +79,8 @@ private: ...@@ -77,7 +79,8 @@ private:
public: public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool enable_shrink); bool low_latency_mode, bool explicitly_destroy, bool enable_shrink,
bool enable_dispatch_ll_layered, bool enable_combine_overlap);
~Buffer() noexcept(false); ~Buffer() noexcept(false);
...@@ -183,6 +186,9 @@ public: ...@@ -183,6 +186,9 @@ public:
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& packed_recv_count,
const std::optional<torch::Tensor>& comp_signal,
int block_m, int threshold, int num_sms,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats, const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool use_logfmt,
......
...@@ -150,6 +150,20 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -150,6 +150,20 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int quant_type, int group_size, bool fp8_round_scale, int quant_type, int group_size, bool fp8_round_scale,
void* workspace, int num_device_sms, hipStream_t stream, int phases); void* workspace, int num_device_sms, hipStream_t stream, int phases);
void dispatch_ll_layered(bool dispatch_ll_dispatch_opt,
void* packed_recv_x, void* packed_recv_x_scales,
int64_t* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int quant_type, int quant_group_size, bool fp8_round_scale,
void* workspace, int num_device_sms,
hipStream_t stream, int phases);
void combine(void* combined_x, void combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const int64_t* topk_idx, const float* topk_weights,
...@@ -163,6 +177,24 @@ void combine(void* combined_x, ...@@ -163,6 +177,24 @@ void combine(void* combined_x,
void* workspace, int num_device_sms, hipStream_t stream, void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy); int phases, bool zero_copy);
void combine_sbo(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int64_t* src_info, const int64_t* layout_range,
// Overlap 新增控制参数
bool disable_ll_layered,
int* packed_recv_count, int* comp_signal,
int block_m, int threshold, int num_sms,
// 同步与统计参数
int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int,
// 维度与配置参数
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
// 系统资源与执行参数
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy);
} // namespace internode_ll } // namespace internode_ll
} // namespace deep_ep } // namespace deep_ep
...@@ -1086,6 +1086,1041 @@ void combine(void* combined_x, ...@@ -1086,6 +1086,1041 @@ void combine(void* combined_x,
#undef COMBINE_LAUNCH_CASE #undef COMBINE_LAUNCH_CASE
} }
template <int kHidden, int kQuantType=0, int kQuantGroupSize=0, int kMaxNumWarps=16>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
dispatch_ll_layered(
bool disable_ll_layered,
void* packed_recv_x, void* packed_recv_x_scales,
int64_t* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
bool fp8_round_scale, int phases) {
// 定义量化类型的枚举
enum class QuantType {
None = 0, // 不进行量化
Int8 = 1, // 采用 Int8 量化
FP8_E4M3 = 2, // 采用 FP8 量化 __HIP_E4M3
FP8_UE8M0 = 3, // 采用 FP8 量化 DeepseekV3.1的 UE8M0
FP8_E5M2 = 4 // 采用 FP8 量化 __HIP_E5M2
};
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
const auto num_sms = static_cast<int>(gridDim.x);
const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_local_experts = num_experts / num_ranks;
const auto warp_group_id = warp_id / num_warps_per_group;
const auto sub_warp_id = warp_id % num_warps_per_group;
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
char* rdma_recv_x_cahr_ptr = reinterpret_cast<char*>(rdma_recv_x);
const auto num_nvl_ranks = NUM_MAX_NVL_PEERS;
const auto num_nodes = num_ranks / num_nvl_ranks;
int* data_ready_counter = reinterpret_cast<int*>(rdma_recv_count + num_experts);
int* data_ready_send_buffer =
data_ready_counter + num_nodes * num_max_dispatch_tokens_per_rank * num_nvl_ranks;
int* next_clean_data_ready_counter = reinterpret_cast<int*>(next_clean + num_experts);
if (!disable_ll_layered) {
if (thread_id < num_nvl_ranks) {
__hip_atomic_store(data_ready_send_buffer + thread_id, 2, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM);
}
}
__syncthreads();
// May extract UE8M0 from the scales
constexpr bool kUseQuant8Bit = kQuantType > 0;
constexpr bool kUseUE8M0 = kQuantType == 3; // QuantType::FP8_UE8M0
using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
// FP8 staffs
constexpr int kNumPerChannels = QUANTIZATION_GROUPSIZE;
constexpr int kNumScales = kHidden / kNumPerChannels;
const size_t hidden_bytes = kHidden * (kUseQuant8Bit ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
const size_t hidden_int4 = hidden_bytes / sizeof(int4);
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
using vec_t = typename std::conditional<kUseQuant8Bit, int2, int4>::type;
const size_t num_bytes_per_meta = sizeof(int4);
const size_t num_bytes_per_data = (kUseQuant8Bit ? (kHidden + (kQuantGroupSize == 0 ? 4 : kNumScales) * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
const size_t num_bytes_per_msg = num_bytes_per_meta + num_bytes_per_data;
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
char* rdma_recv_x_meta = rdma_recv_x_cahr_ptr;
char* rdma_recv_x_data = rdma_recv_x_cahr_ptr + num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_meta;
// Expert counts
__shared__ int shared_num_tokens_sent_per_expert[kMaxNumWarps];
// Sending phase
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_DISPATCH_RECV;
// There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information
if (warp_id < num_warps) {
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16);
constexpr int kNumThreadPerGroup = QUANTIZATION_GROUPSIZE / kNumElemsPerRead;
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = num_warps * kWarpSize;
constexpr int hidden_bf16_int4 = kHidden / kNumElemsPerRead;
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
// Overlap top-k index read and source token index write
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
// 用于记录per-channel量化的amax
__shared__ float channel_amaxf[kNumScales];
if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
if (thread_id < kNumScales) {
channel_amaxf[thread_id] = 0.0;
}
__syncthreads();
}
// FP8 cast
#pragma unroll
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// Read
auto int4_value = __ldg(x_int4 + i);
if constexpr(kUseQuant8Bit) {
// Calculate local amax
auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead];
float amax = 0.0, scale, scale_inv;
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++ j) {
fp32_values[j] = static_cast<float>(bf16_values[j]);
amax = fmaxf(amax, fabsf(fp32_values[j]));
}
// Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
amax = warp_reduce_max<kNumThreadPerGroup>(amax);
const int scale_offset = i * kNumElemsPerRead / QUANTIZATION_GROUPSIZE;
if constexpr(kQuantGroupSize == 0) {
// 记录每128个数的最大值
channel_amaxf[scale_offset] = fmaxf(amax, channel_amaxf[scale_offset]);
} else {
calculate_quant8bit_scales<kQuantType>(amax, scale, scale_inv, fp8_round_scale);
if (lane_id % kNumThreadPerGroup == 0)
rdma_x_scales[scale_offset] = scale_inv;
// Cast into send buffer
vec_t int2_value;
pack_quantized_values<kQuantType, kNumElemsPerRead>(fp32_values, scale, int2_value);
rdma_x_vec[i] = int2_value;
}
} else {
// Reinterpret-cast is for C++14 compatibility
rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
}
}
__syncthreads();
if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
float amax_per_token = 0.0;
// 并行规约,计算每个token的amax
for (int s = 0; s < kNumScales; s+=kWarpSize) {
int src_idx = s + lane_id;
float tmp_amaxf = 0;
if(src_idx < kNumScales) {
tmp_amaxf = channel_amaxf[src_idx];
}
tmp_amaxf = warp_reduce_max<kWarpSize>(tmp_amaxf);
channel_amaxf[0] = fmaxf(tmp_amaxf, channel_amaxf[0]);
__syncthreads();
}
amax_per_token = channel_amaxf[0];
// 根据最大值计算scale
float scale, scale_inv;
calculate_quant8bit_scales<kQuantType>(amax_per_token, scale, scale_inv, fp8_round_scale);
if (thread_id == 0) {
rdma_x_scales[0] = scale_inv;
}
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// Read
auto int4_value = __ldg(x_int4 + i);
auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
// Cast into send buffer
vec_t int2_value;
pack_quantized_values<kQuantType, kNumElemsPerRead>(bf16_values, scale, int2_value);
rdma_x_vec[i] = int2_value;
}
__syncthreads();
}
// Issue IBGDA sends
if (dst_expert_idx >= 0) {
int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
slot_idx = shfl_sync(slot_idx, 0);
const auto dst_rank = dst_expert_idx / num_local_experts;
const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
if(!disable_ll_layered){
int send_node_id = dst_expert_idx / num_local_experts / num_nvl_ranks;
auto real_write_dst_rank = dst_rank / num_nvl_ranks * num_nvl_ranks +
rank % num_nvl_ranks; // send data to same gpu_device_id_rank(same-rail rdma traffic)
auto real_dst_expert_id = real_write_dst_rank * num_local_experts + dst_expert_local_idx;
auto tmp_dst_expert_id = lane_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id)) : -1;
auto tmp_dst_node_id = tmp_dst_expert_id >= 0 ? tmp_dst_expert_id / num_local_experts / num_nvl_ranks : -1;
for (int i = 0; i < warp_id; ++i) {
auto dst_node_id = shfl_sync(tmp_dst_node_id, i); // broadcast
if (dst_node_id == send_node_id) { // whether to send repeatedly
send_node_id = -1;
break;
}
}
if (send_node_id != -1) {
// ======================================= token data ==========================================
int* src_data_ptr = rdma_x_src_idx + 4;
char* dst_data_ptr = rdma_recv_x_data +
(rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_bytes_per_data +
token_idx * num_bytes_per_data;
const auto p2p_data_ptr = internode::shmem_get_p2p_ptr((void*)(dst_data_ptr), rank, real_write_dst_rank);
if (!p2p_data_ptr) {
internode_ll_putmem_nbi(
reinterpret_cast<void*>(dst_data_ptr), reinterpret_cast<void*>(src_data_ptr),
num_ranks, real_write_dst_rank, dst_expert_local_idx, num_bytes_per_data);
} else {
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_data_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(p2p_data_ptr);
UNROLLED_WARP_COPY_LL(8, lane_id, num_bytes_per_data / sizeof(int4), dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
}
// ======================================== token data flag =======================================
uint64_t src_data_flag_ptr = reinterpret_cast<uint64_t>(data_ready_send_buffer);
const auto data_ready_counter_ptr = reinterpret_cast<uint64_t>(data_ready_counter) +
(rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_nvl_ranks * sizeof(int) +
token_idx * num_nvl_ranks * sizeof(int);
uint64_t data_ready_counter_p2p_ptr = internode::shmem_get_p2p_ptr((void*)(data_ready_counter_ptr), rank, real_write_dst_rank);
if (data_ready_counter_p2p_ptr == 0) {
// internode::shmemx_int8_put_nbi_warp_refactoring(
// reinterpret_cast<signed char*>(data_ready_counter_ptr), reinterpret_cast<signed char*>(src_data_flag_ptr),
// num_nvl_ranks * sizeof(int), num_ranks + dst_expert_local_idx * num_ranks + real_write_dst_rank, rank, real_write_dst_rank, true);
internode_ll_putmem_nbi(
reinterpret_cast<void*>(data_ready_counter_ptr), reinterpret_cast<void*>(src_data_flag_ptr),
num_ranks, real_write_dst_rank, dst_expert_local_idx, num_nvl_ranks * sizeof(int));
} else {
int* dst_int_ptr = reinterpret_cast<int*>(data_ready_counter_p2p_ptr);
if(lane_id < num_nvl_ranks){
__hip_atomic_store(dst_int_ptr + lane_id, 2, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM);
}
}
}
// ========================= meta data=============================
const auto src_meta_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
const auto dst_meta_ptr = reinterpret_cast<uint64_t>(rdma_recv_x_meta) +
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_meta +
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_meta +
slot_idx * num_bytes_per_meta;
uint64_t p2p_meta_ptr = internode::shmem_get_p2p_ptr((void*)(dst_meta_ptr), rank, dst_rank);
if (!p2p_meta_ptr) {
// internode::shmemx_int8_put_nbi_warp_refactoring(
// reinterpret_cast<signed char*>(dst_meta_ptr), reinterpret_cast<signed char*>(src_meta_ptr),
// num_bytes_per_meta, num_ranks + dst_expert_local_idx * num_ranks + dst_rank, rank, dst_rank, true);
internode_ll_putmem_nbi(
reinterpret_cast<void*>(dst_meta_ptr), reinterpret_cast<void*>(src_meta_ptr),
num_ranks, dst_rank, dst_expert_local_idx, num_bytes_per_meta);
} else {
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_meta_ptr);
int4* dst_int4_ptr = reinterpret_cast<int4*>(p2p_meta_ptr);
if(lane_id==0){
dst_int4_ptr[0] = src_int4_ptr[0];
}
}
syncwarp();
lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + real_dst_expert_id, 1) : 0;
} else {
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg;
// 通过 shmem_get_p2p_ptr 获取 当前远程指针能否可达
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
internode_ll_putmem_nbi((void*)dst_ptr, (void*)src_ptr,
num_ranks, dst_rank, dst_expert_local_idx,
num_bytes_per_msg);
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(p2p_ptr);
UNROLLED_WARP_COPY_LL(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
}
// Increase counter after finishing
syncwarp();
lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
}
}
}
}
if (warp_id == num_warps - 1) {
// EP_DEVICE_ASSERT(num_sms > 1);
if (sm_id == 0) {
if (disable_ll_layered) {
// The first SM is also responsible for checking QPs
// The first SM is also responsible for cleaning the next buffer
#pragma unroll
for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
next_clean[i] = 0;
// Notify before executing `int_p`
syncwarp();
#pragma unroll
for (int i = lane_id; i < num_experts; i += kWarpSize)
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
}
}
// This SM should be responsible for some destination experts, read `topk_idx` for them
int expert_count[kMaxNumWarps] = {0};
int waiting_flag[kMaxNumWarps] = {0};
const auto expert_begin_idx = sm_id * num_warp_groups;
const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);
// Per lane count
#pragma unroll 8
for (int i = lane_id; i < num_tokens * num_topk; i += kWarpSize) {
auto idx = static_cast<int>(__ldg(topk_idx + i));
if (idx >= expert_begin_idx and idx < expert_end_idx)
expert_count[idx - expert_begin_idx] ++;
if (!disable_ll_layered) {
if (idx < 0)
continue;
const auto dst_rank = idx / num_local_experts;
const auto dst_expert_local_idx = idx % num_local_experts;
auto real_write_dst_rank = dst_rank / num_nvl_ranks * num_nvl_ranks + rank % num_nvl_ranks;
auto real_dst_expert_id = real_write_dst_rank * num_local_experts + dst_expert_local_idx;
if (real_dst_expert_id >= expert_begin_idx and real_dst_expert_id < expert_end_idx)
waiting_flag[real_dst_expert_id - expert_begin_idx] ++;
}
}
// Warp reduce
#pragma unroll
for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
auto waiting_flag_sum = 0;
if (!disable_ll_layered) { // only open ll dispatch opt, should do
waiting_flag_sum = warp_reduce_sum(waiting_flag[i - expert_begin_idx]);
}
if (lane_id == 0) {
shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - waiting_flag_sum - sum);
}
}
}
if (!disable_ll_layered and sm_id == num_sms - 1) {
// The first SM is also responsible for cleaning the next buffer
for (int i = thread_id; i < num_experts; i += blockDim.x) // clean for combine
next_clean[i] = 0;
// clean data ready flag
for (int i = thread_id; i < num_max_dispatch_tokens_per_rank * num_ranks; i += blockDim.x) {
int token_idx = i / num_ranks;
int rank_id = i % num_ranks;
auto node_id = rank_id / num_nvl_ranks;
auto nvl_rank_id = rank_id % num_nvl_ranks;
auto* data_ready_flag_ptr = reinterpret_cast<int*>(next_clean_data_ready_counter) +
node_id * num_max_dispatch_tokens_per_rank * num_nvl_ranks + token_idx * num_nvl_ranks + rank % num_nvl_ranks;
EP_DEVICE_ASSERT(data_ready_flag_ptr - next_clean_data_ready_counter <
num_max_dispatch_tokens_per_rank * num_nodes * num_nvl_ranks * sizeof(int));
const auto data_ready_p2p_src_ptr =
internode::shmem_get_p2p_ptr((void*)(data_ready_flag_ptr), rank, rank / num_nvl_ranks * num_nvl_ranks + nvl_rank_id);
reinterpret_cast<int*>(data_ready_p2p_src_ptr)[0] = 0;
}
__syncthreads();
#pragma unroll
for (int i = thread_id; i < num_experts; i += blockDim.x)
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
}
__syncthreads();
// Issue count sends
if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
const auto dst_rank = responsible_expert_idx / num_local_experts;
const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];
// Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
auto dst_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
// 通过 shmem_get_p2p_ptr 获取 当前远程指针能否可达
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
internode_ll_long_atomic_add(dst_ptr, -num_tokens_sent - 1,
num_ranks, dst_rank, dst_expert_local_idx);
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
st_na_release(reinterpret_cast<int *>(p2p_ptr), -num_tokens_sent - 1);
}
// Clean workspace for next use
atomic_counter_per_expert[responsible_expert_idx] = 0;
atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
// Clean `packed_recv_count`
if (dst_rank == 0)
packed_recv_count[dst_expert_local_idx] = 0;
}
syncwarp();
// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;
// For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
if (phases & LOW_LATENCY_SEND_PHASE){
grid_barrier(global_atomic_counter, num_sms);
}
// 16 is the max possible number of warps in AMD GPUs
constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0;
}
__syncthreads();
// Receiving and packing
if (responsible_expert_idx < num_experts) {
const auto src_rank = responsible_expert_idx / num_local_experts;
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
uint8_t* rdma_recv_x_uint8 = nullptr;
if (!disable_ll_layered) {
rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x_meta) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_meta +
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_meta;
}
if (disable_ll_layered) {
rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
}
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
const auto num_aligned_scales = ALIGN<int>(kNumScales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank *
(kQuantGroupSize == 0 ? 1 : num_aligned_scales);
// Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kMaxNumWarps], shared_recv_token_begin_idx[kMaxNumWarps];
// Wait tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0
int num_recv_tokens, recv_token_begin_idx;
// EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 1 and lane_id == 0) {
while ((num_recv_tokens = ld_acquire_global(reinterpret_cast<int*>(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0);
num_recv_tokens = -num_recv_tokens - 1;
recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
}
// no needs to reset because there is no iteration
if (lane_id == 0){
volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
}
syncwarp();
while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
num_recv_tokens = shared_num_recv_tokens[warp_group_id];
recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
const auto real_read_src_rank = src_rank % num_nvl_ranks + rank / num_nvl_ranks * num_nvl_ranks;
// Copy tokens
EP_STATIC_ASSERT(kNumScales <= 64, "Invalid hidden size");
for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
int4* src_data = nullptr;
if (!disable_ll_layered) {
int* src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_meta);
int src_token_idx = __builtin_nontemporal_load(src_src_idx);
if (lane_id == 0) {
recv_src_info[recv_token_begin_idx + i] = pack2<int, int64_t>(src_token_idx, src_rank);
}
const auto data_ready_flag_src_ptr = data_ready_counter +
(src_rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_nvl_ranks +
src_token_idx * num_nvl_ranks +
rank % num_nvl_ranks;
const auto src_data_ready_flag_p2p_ptr =
reinterpret_cast<int*>(internode::shmem_get_p2p_ptr((void*)(data_ready_flag_src_ptr), rank, real_read_src_rank));
if (lane_id == 0) {
int tmp = 0;
auto start_time = clock64();
bool flag_get = false;
while (tmp != 2) {
tmp = __hip_atomic_load(src_data_ready_flag_p2p_ptr, __ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_SYSTEM);
if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {
printf(
"DeepEP ll dispatch recv data timeout, src_rank:%d, dst_rank: %d, real_read_src_rank:%d,src_token_idx:%d "
"dst RDMA lane: %d, num_recv_tokens: %d\n",
src_rank,
rank,
real_read_src_rank,
src_token_idx,
lane_id,
num_recv_tokens
);
break;
}
}
}
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_recv_x_data) +
(src_rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_bytes_per_data
+ src_token_idx * num_bytes_per_data;
uint64_t src_ptr_p2p = internode::shmem_get_p2p_ptr((void*)(src_ptr), rank, real_read_src_rank);
src_data = reinterpret_cast<int4*>(src_ptr_p2p);
}
if (disable_ll_layered) {
const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
int src_token_idx = __builtin_nontemporal_load(src_src_idx);
if (lane_id == 0)
// 加入 源rank 信息
recv_src_info[recv_token_begin_idx + i] = pack2<int, int64_t>(src_token_idx, src_rank);
syncwarp();
// Copy data
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls
src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
}
const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// Copy scales
if constexpr(kUseQuant8Bit) {
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
const auto token_idx = recv_token_begin_idx + i;
const auto token_stride = num_elems_per_pack;
const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
if constexpr(kQuantGroupSize == 0) {
if (lane_id == 0) {
recv_x_scales[token_idx] = ld_nc_global(src_scales);
}
} else {
if (lane_id < kNumScales) {
const auto pack_idx = lane_id / num_elems_per_pack;
const auto elem_idx = lane_id % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
if (lane_id + kWarpSize < kNumScales) {
const auto pack_idx = (lane_id + kWarpSize) / num_elems_per_pack;
const auto elem_idx = (lane_id + kWarpSize) % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + kWarpSize));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
}
}
}
}
}
void dispatch_ll_layered(bool dispatch_ll_dispatch_opt,
void* packed_recv_x, void* packed_recv_x_scales,
int64_t* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int quant_type, int quant_group_size, bool fp8_round_scale,
void* workspace, int num_device_sms,
hipStream_t stream, int phases) {
constexpr int kMaxNumWarps = 16;
constexpr int kNumMaxTopK = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = kMaxNumWarps / num_warp_groups;
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_sms = ceil_div(num_experts, num_warp_groups);
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
// Workspace checks
auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
// 限制groupsize的大小
EP_HOST_ASSERT(quant_group_size == 0 || quant_group_size == 128);
/*量化类型枚举
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2)
*/
#define DISPATCH_LL_LAUNCH_CASE(hidden) \
{ \
auto dispatch_func = dispatch_ll_layered<hidden, 0, 0, kMaxNumWarps>; \
if (quant_group_size == 0) { \
switch (quant_type) { \
case 1: dispatch_func = dispatch_ll_layered<hidden, 1, 0, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch_ll_layered<hidden, 2, 0, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch_ll_layered<hidden, 3, 0, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch_ll_layered<hidden, 4, 0, kMaxNumWarps>; break; \
} \
} else { \
switch (quant_type) { \
case 1: dispatch_func = dispatch_ll_layered<hidden, 1, 128, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch_ll_layered<hidden, 2, 128, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch_ll_layered<hidden, 3, 128, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch_ll_layered<hidden, 4, 128, kMaxNumWarps>; break; \
} \
} \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, dispatch_ll_dispatch_opt, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, packed_recv_count, \
global_atomic_counter, \
rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, fp8_round_scale, phases); \
} \
break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(DISPATCH_LL_LAUNCH_CASE);
#undef DISPATCH_LL_LAUNCH_CASE
}
/*
combine 启用 overlop 后的实现
*/
template <int kHidden, int kNumMaxTopk, int kMaxNumWarps=16>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
combine_sbo(bool disable_ll_layered,
void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int64_t* src_info, const int64_t* layout_range,
// Overlap specific parameters
int* packed_recv_count, int* comp_signal, int block_m, int threshold,
int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int,
int* atomic_clean_flag, int* atomic_finish_counter_per_expert,
int num_combined_tokens, int hidden, int num_topk,
int num_max_dispatch_tokens_per_rank,
int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
int phases, bool zero_copy) {
// 假设 启用 3 个block
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto num_threads = static_cast<int>(blockDim.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
const auto num_local_experts = num_experts / num_ranks; // 16
const auto warp_group_id = warp_id / num_warps_per_group; // 0 0 0 ... 0
const auto sub_warp_id = warp_id % num_warps_per_group; // 0 1 2 ... 15
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; // 这意味着 一次 并行处理 3个专家 0 1 2
int* next_clean_data_ready_counter = reinterpret_cast<int*>(next_clean + num_experts);
const auto num_nvl_ranks = NUM_MAX_NVL_PEERS;
const auto num_nodes = num_ranks / num_nvl_ranks;
// hidden_bf16_int4: bf16 的 token 包含多少个 int4
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(hip_bfloat16);
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// Message package
EP_STATIC_ASSERT(kHidden % QUANTIZATION_GROUPSIZE == 0, "Invalid hidden");
constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// Shared between warps in sms for overlap mode, where each sm only has one warp group
__shared__ volatile int shared_vaild_signal_prefix_sum[40]; // 用于统计 本地专家 有效信号 的 前缀和
// Sending phase
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_COMBINE_RECV;
if (!disable_ll_layered and sm_id == num_sms - 1) {
#pragma unroll
for (int i = thread_id; i < num_experts; i += num_threads)
next_clean[i] = 0;
// clean data ready flag
for (int i = thread_id; i < num_max_dispatch_tokens_per_rank * num_ranks; i += num_threads) {
int token_idx = i / num_ranks;
int rank_id = i % num_ranks;
{
auto node_id = rank_id / num_nvl_ranks;
auto nvl_rank_id = rank_id % num_nvl_ranks;
auto* data_ready_flag_ptr = reinterpret_cast<int*>(next_clean_data_ready_counter) +
node_id * num_max_dispatch_tokens_per_rank * num_nvl_ranks + token_idx * num_nvl_ranks + rank % num_nvl_ranks;
EP_DEVICE_ASSERT(data_ready_flag_ptr - next_clean_data_ready_counter <
num_max_dispatch_tokens_per_rank * num_nodes * num_nvl_ranks * sizeof(int));
const auto data_ready_p2p_src_ptr =
internode::shmem_get_p2p_ptr((void*)(data_ready_flag_ptr), rank, rank / num_nvl_ranks * num_nvl_ranks + nvl_rank_id);
reinterpret_cast<int*>(data_ready_p2p_src_ptr)[0] = 0;
}
}
// Notify before executing `int_p`
__syncthreads();
if (thread_id == 0)
atomic_add_release_global(atomic_clean_flag, num_experts);
}
if (disable_ll_layered) {
// Clean up next buffer
if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
#pragma unroll
for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
next_clean[i] = 0;
// Notify before executing `int_p`
syncwarp();
if (lane_id == 0)
atomic_add_release_global(atomic_clean_flag, num_experts);
}
}
__syncthreads();
// ========================================
// shared_vaild_signal_sum: 本地专家的总信号量
// shared_local_expert_idx: 共享内存中的 本地专家索引。初始置为 0 , 表明 当前 block 当前在 处理的 本地专家索引
__shared__ int shared_vaild_signal_sum, shared_local_expert_idx;
// 计算每个 本地专家 有效信号 计数 的 前缀和,即使没有 token, 也算作一个 任务
if (sub_warp_id == 0 and lane_id == 0) { // 0号 warp 的 0号线程 执行下述操作
shared_vaild_signal_prefix_sum[0] = (packed_recv_count[0] == 0 ? 1 : ceil_div(packed_recv_count[0], block_m));
shared_local_expert_idx = 0; // 共享内存中 本地专家索引 置为 0
for (int i = 1; i < num_local_experts; i++) {
shared_vaild_signal_prefix_sum[i] =
shared_vaild_signal_prefix_sum[i - 1] + (packed_recv_count[i] == 0 ? 1 : ceil_div(packed_recv_count[i], block_m));
}
shared_vaild_signal_sum = shared_vaild_signal_prefix_sum[num_local_experts - 1];
}
__syncthreads(); // 等待前缀和 统计完成 16个 warp 同步等待
// 每个 block 负责一个 处理信号,并循环处理到 最后
for (int vaild_signal_idx = sm_id; vaild_signal_idx < shared_vaild_signal_sum; vaild_signal_idx += num_sms) {
// ====================== 16个 warp 进入 ======================
// 通过扫描前缀和数组找到当前处理的本地专家索引,并记录在 shared_local_expert_idx
if (sub_warp_id == 0 and lane_id == 0) {
while (vaild_signal_idx >= shared_vaild_signal_prefix_sum[shared_local_expert_idx])
shared_local_expert_idx++;
}
__syncthreads();
// ===========================================
// shared_local_expert_idx: 当前处理的任务块 是哪个本地专家
// 上述 操作 确定了 当前 block 负责处理的本地专家为 shared_local_expert_idx
// 需要依据 shared_local_expert_idx 本地索引确定其他 地址
const auto local_expert_idx = shared_local_expert_idx; // 当前处理 的 本地专家索引
const auto global_expert_idx = rank * num_local_experts + local_expert_idx; // 获取 本地专家 在全局中的索引
const auto local_x = static_cast<const int4*>(x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto rdma_send_x_vec = static_cast<uint8_t*>(rdma_send_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// ================================ 等待相应的 comp_signal 达到阈值
//----------------------- 确定 当前等待的信号量位置
// num_tokens_per_expert:当前 负责的专家 dispatch 阶段 接收的 总 token 数
// num_signal_per_expert:当前 负责的专家 需要等待的总 信号 数
// local_expert_signal_idx: 当前处理的信号总索引,是 当前处理专家的 第几个信号
int num_tokens_per_expert, num_signal_per_expert, local_expert_signal_idx;
const int* gemm_comp_signal;
num_tokens_per_expert = packed_recv_count[local_expert_idx]; // 当前专家 dispatch 阶段接收的 总 token 数
num_signal_per_expert = ceil_div(num_ranks * num_max_dispatch_tokens_per_rank, block_m); // 每个专家的 最大 信号数
local_expert_signal_idx =
(local_expert_idx == 0) ? vaild_signal_idx : vaild_signal_idx - shared_vaild_signal_prefix_sum[local_expert_idx - 1]; // 当前专家 中的 信号索引
gemm_comp_signal = comp_signal + num_signal_per_expert * local_expert_idx + local_expert_signal_idx;
//----------------------- 循环等待 信号量到达 阈值
if (sub_warp_id == 0 and lane_id == 0 and num_tokens_per_expert != 0) { // 当前专家 dispatch 阶段接收的 token 数 不是 0 的话,循环等待 信号量的值 到达 阈值
while (ld_acquire_global(gemm_comp_signal) != threshold)
;
}
__syncthreads();
// ============================== 发射 RDMA 指令 ==============================
// ------------------------------ 确定 处理的 token 起始位置 和 结束位置 -----------------
auto token_start_idx = local_expert_signal_idx * block_m;
auto token_end_idx = min((local_expert_signal_idx + 1) * block_m, num_tokens_per_expert);
// 16个 warp 每个warp 负责一个 token 的发射
for (int token_idx = sub_warp_id + token_start_idx; token_idx < token_end_idx; token_idx += num_warps_per_group) {
const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
const auto dst_rank = static_cast<int>(__ldg(local_src_info + token_idx) >> 32);
const auto src_idx = static_cast<int>(__ldg(local_src_info + token_idx) & 0xffffffff);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy){
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
}
internode_ll_putmem_nbi((void*)dst_ptr, (void*)buf_ptr,
num_ranks, dst_rank, local_expert_idx,
hidden * sizeof(hip_bfloat16));
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(x_int4);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(p2p_ptr);
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
}
}
__syncthreads(); // 等待 16 个 warp 都完成 RDMA 发射
// ================================= 当前所有 RDMA 下发完成后,判断是不是要 发射 完成的 flag=====================================
bool put_finish_flag = false; // 标记是不是要发射 RDMA 结束标记
// 判断是不是 到了 当前专家处理的 最后
if (sub_warp_id == 0) { //
if (lane_id == 0) {
const auto finish_counter = (num_tokens_per_expert == 0 ? 1 : ceil_div(num_tokens_per_expert, block_m)); // 获取当前专家 发送的 总 的信号数
if ((atomicAdd(atomic_finish_counter_per_expert + local_expert_idx, 1) + 1) == finish_counter)
put_finish_flag = true;
}
put_finish_flag = shfl_sync(put_finish_flag, 0);
}
__syncthreads();
// 通知其他 所有 rank,当前本地专家的 token 已经发射完成
if (sub_warp_id == 0 and put_finish_flag) {
for (int dst_rank = lane_id; dst_rank < num_ranks; dst_rank += 64) {
while (ld_acquire_global(atomic_clean_flag) == 0);
auto dst_ptr = rdma_recv_flag + global_expert_idx;
// 通过 shmem_get_p2p_ptr 获取 当前远程指针能否可达
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
internode_ll_long_atomic_add(dst_ptr, 1, num_ranks, dst_rank, local_expert_idx);
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
st_na_release(reinterpret_cast<int *>(p2p_ptr), 1);
}
atomic_add_release_global(atomic_clean_flag, -1);
}
if (lane_id == 0) // 清理 标记数组
atomic_finish_counter_per_expert[local_expert_idx] = 0;
}
__syncthreads();
}
// Receiving phase
LOW_LATENCY_COMBINE_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;
// Wait all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) {
// EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 0 and lane_id == 0) {
const auto src_rank = responsible_expert_idx / num_local_experts;
auto start_time = wall_clock64();
uint64_t wait_recv_cost = 0;
while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0 // recv not ready
&& (wait_recv_cost = wall_clock64() - start_time) <= NUM_TIMEOUT_CYCLES // not timeout
);
// Mask rank if timeout
if (wait_recv_cost > NUM_TIMEOUT_CYCLES) {
printf("Warning: DeepEP timeout for combine receive, rank %d, local_expert_idx %d, src_rank %d\n",
rank, responsible_expert_idx % num_local_experts, src_rank);
}
if (combine_wait_recv_cost_stats != nullptr) {
atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);
}
}
}
grid_barrier(global_atomic_counter, num_sms);
// Reduce tokens with FP8 cast
// EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
if (thread_id < hidden_bf16_int4) {
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// Read top-k indices and weights
int reg_topk_idx[kNumMaxTopk];
float reg_topk_weights[kNumMaxTopk];
#pragma unroll
for (int i = 0; i < num_topk; ++ i) {
reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
}
float combined_values[kNumElemsPerInt4] = {0.0f};
#pragma unroll
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// Read from sources
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) +
(reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
// Reduce
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
const auto x_bf16 = reinterpret_cast<hip_bfloat16*>(&x_vec);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
}
// Write results
int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
auto combined_bf16 = reinterpret_cast<hip_bfloat16*>(&combined_values);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_bf16[j] = static_cast<hip_bfloat16>(combined_values[j]);
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
}
}
}
void combine_sbo(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int64_t* src_info, const int64_t* layout_range,
// Overlap 新增控制参数
bool disable_ll_layered,
int* packed_recv_count, int* comp_signal,
int block_m, int threshold, int num_sms,
// 同步与统计参数
int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int,
// 维度与配置参数
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
// 系统资源与执行参数
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy) {
constexpr int kMaxNumWarps = 16;
constexpr int kNumMaxTopk = 11;
int num_warp_groups, num_warps_per_group, num_recv_per_sm, num_warps;
if (phases == LOW_LATENCY_SEND_PHASE) { // 如果启用 overlop 必须是 send 阶段
num_warp_groups = 1; // 一个 block 只有一个 warp 组
num_warps_per_group = 16; // 16 个 warp 每个 warp 64 线程
num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0 and block_m > 0 and threshold > 0);
num_warps = num_warp_groups * num_warps_per_group;
} else {
num_warp_groups = ceil_div(num_experts, num_device_sms);
num_warps_per_group = kMaxNumWarps / num_warp_groups;
num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
num_warps = num_warp_groups * num_warps_per_group;
num_sms = max(ceil_div(num_experts, num_warp_groups), num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm));
}
// Check workspace
auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
auto atomic_finish_counter_per_expert = atomic_clean_flag + 1; // overlop 新增使用
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
#define COMBINE_OVERLOP_LAUNCH_CASE(hidden) \
{ \
auto combine_overlop_func = combine_sbo<hidden, kNumMaxTopk, kMaxNumWarps>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_overlop_func, \
disable_ll_layered, \
combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
packed_recv_count, comp_signal, block_m, threshold, \
global_atomic_counter, combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \
atomic_clean_flag, atomic_finish_counter_per_expert, \
num_combined_tokens, hidden, \
num_topk, num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, phases, zero_copy); \
} \
break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(COMBINE_OVERLOP_LAUNCH_CASE);
#undef COMBINE_OVERLOP_LAUNCH_CASE
}
} // namespace internode_ll } // namespace internode_ll
} // namespace deep_ep } // namespace deep_ep
...@@ -40,6 +40,8 @@ class Buffer: ...@@ -40,6 +40,8 @@ class Buffer:
allow_mnnvl: bool = False, allow_mnnvl: bool = False,
explicitly_destroy: bool = False, explicitly_destroy: bool = False,
enable_shrink: bool = False, enable_shrink: bool = False,
enable_dispatch_ll_layered: bool = False,
enable_combine_overlap: bool = False,
) -> None: ) -> None:
""" """
Initialize the communication buffer. Initialize the communication buffer.
...@@ -60,6 +62,8 @@ class Buffer: ...@@ -60,6 +62,8 @@ class Buffer:
otherwise, the resources will be released by the destructor. otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang. Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically. enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically.
enable_dispatch_ll_layered: Enable low-latency mode with hierarchical dispatch operators.
enable_combine_overlap: deepgemm DOWN gemm overlop combine send
""" """
check_nvlink_connections(group) check_nvlink_connections(group)
...@@ -72,6 +76,10 @@ class Buffer: ...@@ -72,6 +76,10 @@ class Buffer:
self.low_latency_mode = low_latency_mode self.low_latency_mode = low_latency_mode
self.explicitly_destroy = explicitly_destroy self.explicitly_destroy = explicitly_destroy
self.enable_shrink = enable_shrink self.enable_shrink = enable_shrink
if enable_dispatch_ll_layered and enable_shrink: # Currently, the layered algorithm for ll dispatch has been optimized, so the shrink mode is no longer supported.
print("DeepEP [ERROR] not support shrink, disable it", flush=True)
enable_shrink = False
self.runtime = deep_ep_cpp.Buffer( self.runtime = deep_ep_cpp.Buffer(
self.rank, self.rank,
self.group_size, self.group_size,
...@@ -79,7 +87,9 @@ class Buffer: ...@@ -79,7 +87,9 @@ class Buffer:
num_rdma_bytes, num_rdma_bytes,
low_latency_mode, low_latency_mode,
explicitly_destroy, explicitly_destroy,
enable_shrink enable_shrink,
enable_dispatch_ll_layered,
enable_combine_overlap
) )
# Synchronize device IDs # Synchronize device IDs
...@@ -212,7 +222,8 @@ class Buffer: ...@@ -212,7 +222,8 @@ class Buffer:
@staticmethod @staticmethod
def get_low_latency_rdma_size_hint( def get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int, quant_group_size: int = 0 num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int,
enable_dispatch_ll_layered: bool = False, quant_group_size: int = 0
) -> int: ) -> int:
""" """
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16. Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
...@@ -228,7 +239,8 @@ class Buffer: ...@@ -228,7 +239,8 @@ class Buffer:
size: the RDMA buffer size recommended. size: the RDMA buffer size recommended.
""" """
return deep_ep_cpp.get_low_latency_rdma_size_hint( return deep_ep_cpp.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered, quant_group_size
) )
def get_comm_stream(self) -> torch.Stream: def get_comm_stream(self) -> torch.Stream:
...@@ -921,9 +933,11 @@ class Buffer: ...@@ -921,9 +933,11 @@ class Buffer:
recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x
return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook
# noinspection PyTypeChecker def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: tuple,
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, # combine sbo params
handle: tuple, use_logfmt: bool = False, packed_recv_count: torch.Tensor = None, comp_signal: torch.Tensor = None,
block_m: int = -1, threshold: int = -1, num_sms: int = -1,
use_logfmt: bool = False,
zero_copy: bool = False, async_finish: bool = False, zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None, return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \ combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
...@@ -945,13 +959,13 @@ class Buffer: ...@@ -945,13 +959,13 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor. tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function. handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`. with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival. If you not set this flag, the kernel will ensure the data's arrival.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics, combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`. which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
...@@ -964,6 +978,7 @@ class Buffer: ...@@ -964,6 +978,7 @@ class Buffer:
""" """
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
packed_recv_count, comp_signal, block_m, threshold, num_sms,
combine_wait_recv_cost_stats, combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts, num_max_dispatch_tokens_per_rank, num_experts,
use_logfmt, zero_copy, async_finish, return_recv_hook, out) use_logfmt, zero_copy, async_finish, return_recv_hook, out)
......
#!/bin/bash
# rocSHMEM # rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288 export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48 export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9 export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240 export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# duSHMEM # # duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH # export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1 # export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export NVSHMEM_SYMMETRIC_SIZE=10737418240 # export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common # common
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
...@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../ ...@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../
# test # test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
#!/bin/bash
# rocSHMEM # rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288 export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48 export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9 export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240 export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# duSHMEM # # duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH # export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1 # export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export NVSHMEM_SYMMETRIC_SIZE=10737418240 # export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common # common
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
...@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../ ...@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../
# test # test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
...@@ -34,6 +34,10 @@ def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], bu ...@@ -34,6 +34,10 @@ def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], bu
assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks
def ceil_div(a, b):
return (a + b - 1) // b
def test_main(num_tokens: int, def test_main(num_tokens: int,
hidden: int, hidden: int,
num_experts: int, num_experts: int,
...@@ -42,11 +46,16 @@ def test_main(num_tokens: int, ...@@ -42,11 +46,16 @@ def test_main(num_tokens: int,
num_ranks: int, num_ranks: int,
group: dist.ProcessGroup, group: dist.ProcessGroup,
buffer: deep_ep.Buffer, buffer: deep_ep.Buffer,
enable_dispatch_ll_layered: bool = False,
enable_combine_overlap: bool = False,
use_logfmt: bool = False, use_logfmt: bool = False,
seed: int = 0): seed: int = 0):
torch.manual_seed(seed + rank) torch.manual_seed(seed + rank)
random.seed(seed + rank) random.seed(seed + rank)
print(f"enable_dispatch_ll_layered={enable_dispatch_ll_layered}, enable_combine_overlap={enable_combine_overlap}, use_logfmt={use_logfmt}")
assert not (use_logfmt and (enable_dispatch_ll_layered or enable_combine_overlap)), \
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
assert num_experts % num_ranks == 0 assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks num_local_experts = num_experts // num_ranks
...@@ -84,10 +93,13 @@ def test_main(num_tokens: int, ...@@ -84,10 +93,13 @@ def test_main(num_tokens: int,
hash_value, num_times = 0, 0 hash_value, num_times = 0, 0
for x_i, current_x in enumerate(x_list): for x_i, current_x in enumerate(x_list):
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2 if enable_combine_overlap and (not return_recv_hook): # return_recv_hook 为False 时,不能启用 overlop
continue
for quant_type in (0, 1, 2, 3,): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0 dispatch_use_quant = quant_type > 0
for fp8_round_scale in (False, True) if quant_type != 3 else (True, ): for fp8_round_scale in (False, True) if quant_type != 3 else (True,):
for quant_group_size in (0, 128,) if quant_type >= 2 else (0, ): for quant_group_size in (0, 128,) if quant_type >= 2 else (0,):
if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0): if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0):
continue continue
...@@ -131,9 +143,14 @@ def test_main(num_tokens: int, ...@@ -131,9 +143,14 @@ def test_main(num_tokens: int,
recv_x = recv_x[:num_valid_tokens] recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1) recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1) recv_x_amax = recv_x[:, :-128].amax(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
if (enable_dispatch_ll_layered or enable_combine_overlap):
recv_src_info = recv_src_info[:num_valid_tokens] & int_mask # 掩掉多余的信息
else:
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x_amax) assert torch.equal(recv_x_amin, recv_x_amax)
if dispatch_use_quant: if dispatch_use_quant:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007 assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
...@@ -148,6 +165,7 @@ def test_main(num_tokens: int, ...@@ -148,6 +165,7 @@ def test_main(num_tokens: int,
if not fp8_round_scale: if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0 assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_quant: if dispatch_use_quant:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
...@@ -155,19 +173,42 @@ def test_main(num_tokens: int, ...@@ -155,19 +173,42 @@ def test_main(num_tokens: int,
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
# Check combine correctness # Check combine correctness
for zero_copy in (False, ) if use_logfmt else (False, True, ): for zero_copy in (False,) if use_logfmt else (False, True,):
if zero_copy: if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, if enable_combine_overlap:
topk_idx, block_m, threshold, num_sms = 64, 10, 3
topk_weights, total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m) # 每个本地专家 总的信号数??
handle, comp_signal = torch.zeros(num_local_experts * total_num_per_expert, dtype=torch.int32, device='cuda')
use_logfmt=use_logfmt,
async_finish=not return_recv_hook, for i in range(num_local_experts):
zero_copy=zero_copy, vaild_num = ceil_div(packed_recv_count[i], block_m)
return_recv_hook=return_recv_hook, comp_signal[i * total_num_per_expert:i * total_num_per_expert + vaild_num] = threshold
out=out) combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
packed_recv_count=packed_recv_count,
comp_signal=comp_signal,
block_m=block_m,
threshold=threshold,
num_sms=num_sms,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
else:
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
use_logfmt=use_logfmt,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
hook() if return_recv_hook else event.current_stream_wait() hook() if return_recv_hook else event.current_stream_wait()
if do_check: if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
...@@ -177,9 +218,13 @@ def test_main(num_tokens: int, ...@@ -177,9 +218,13 @@ def test_main(num_tokens: int,
hash_value ^= hash_tensor(combined_x) hash_value ^= hash_tensor(combined_x)
if rank == 0: if rank == 0:
print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ", print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass") f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")
print("deep_ep 全部正确性测试完成")
if enable_dispatch_ll_layered or enable_combine_overlap:
return hash_value
# noinspection PyShadowingNames # noinspection PyShadowingNames
def large_gemm_with_hook(hook): def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float) mat_0 = torch.randn((8192, 8192), dtype=torch.float)
...@@ -242,7 +287,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -242,7 +287,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_tokens, hidden = args.num_tokens, args.hidden num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts num_topk, num_experts = args.num_topk, args.num_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) enable_dispatch_ll_layered = args.enable_dispatch_ll_layered
enable_combine_overlap = args.enable_combine_overlap
if enable_dispatch_ll_layered:
enable_combine_overlap = True
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered=enable_dispatch_ll_layered)
if local_rank == 0: if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True) print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group, buffer = deep_ep.Buffer(group,
...@@ -251,7 +302,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -251,7 +302,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=num_experts // num_ranks, num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink, allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
explicitly_destroy=True, explicitly_destroy=True,
allow_mnnvl=args.allow_mnnvl) allow_mnnvl=args.allow_mnnvl,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap
)
print("deep_ep 初始化完成")
test_main(num_tokens, test_main(num_tokens,
hidden, hidden,
num_experts, num_experts,
...@@ -261,6 +316,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -261,6 +316,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group, group,
buffer, buffer,
use_logfmt=args.use_logfmt, use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=1) seed=1)
do_pressure_test = args.pressure_test do_pressure_test = args.pressure_test
...@@ -276,6 +333,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -276,6 +333,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group, group,
buffer, buffer,
use_logfmt=args.use_logfmt, use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed) seed=seed)
for _ in range(20): for _ in range(20):
assert test_main(num_tokens, assert test_main(num_tokens,
...@@ -287,6 +346,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -287,6 +346,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group, group,
buffer, buffer,
use_logfmt=args.use_logfmt, use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed) == ref_hash, f'Error: seed={seed}' seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group # Destroy the buffer runtime and communication group
...@@ -309,6 +370,10 @@ if __name__ == '__main__': ...@@ -309,6 +370,10 @@ if __name__ == '__main__':
parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test') parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode') parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine') parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
# 新版 sbo 需要的
parser.add_argument('--enable-dispatch-ll-layered', action='store_true', help='Enable low-latency layered dispatch optimization')
parser.add_argument("--enable-combine-overlap", action='store_true', help='Enable GEMM-compute/communication overlap in the combine phase')
args = parser.parse_args() args = parser.parse_args()
num_processes = args.num_processes num_processes = args.num_processes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment