Commit c2abba11 authored by lishen's avatar lishen
Browse files

Merge branch 'sbo_tmp' into 'main'

low-latency添加dispatch分层优化和combine gemm overlap

See merge request dcutoolkit/deeplearing/DeepEP!26
parents ea76f44e da6da7c3
...@@ -135,9 +135,11 @@ struct LowLatencyLayout { ...@@ -135,9 +135,11 @@ struct LowLatencyLayout {
} }
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts, int quant_group_size=0) { int num_ranks, int num_experts, bool enable_dispatch_ll_layered=false, int quant_group_size=0) {
const int num_scales = hidden / QUANTIZATION_GROUPSIZE; const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
const int num_nodes = num_ranks / NUM_MAX_NVL_PEERS; // 计算结点数
// Dispatch and combine layout: // Dispatch and combine layout:
// - 2 symmetric odd/even send buffer // - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even receive buffers // - 2 symmetric odd/even receive buffers
...@@ -152,7 +154,9 @@ struct LowLatencyLayout { ...@@ -152,7 +154,9 @@ struct LowLatencyLayout {
(quant_group_size == 0 ? 4 : num_scales) * sizeof(float)); // 应该是1,但是代码中为了满足int4对齐 (quant_group_size == 0 ? 4 : num_scales) * sizeof(float)); // 应该是1,但是代码中为了满足int4对齐
// 与internode_ll::combine 中的 num_bytes_per_slot 相等 // 与internode_ll::combine 中的 num_bytes_per_slot 相等
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) + num_scales * sizeof(__hip_bfloat162); size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) +
(enable_dispatch_ll_layered ? 0 : // 即enable_combine_overlap==true,执行函数combine_sbo
num_scales * sizeof(__hip_bfloat162));
// Send buffer // Send buffer
size_t dispatch_send_buffer_bytes = size_t dispatch_send_buffer_bytes =
...@@ -176,6 +180,10 @@ struct LowLatencyLayout { ...@@ -176,6 +180,10 @@ struct LowLatencyLayout {
// Symmetric signaling buffers // Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t); size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t);
if (enable_dispatch_ll_layered) {
dispatch_recv_count_buffer_bytes +=
NUM_MAX_NVL_PEERS * num_nodes * num_max_dispatch_tokens_per_rank * sizeof(int) + NUM_MAX_NVL_PEERS * sizeof(int);
}
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128); size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
...@@ -205,9 +213,11 @@ struct LowLatencyLayout { ...@@ -205,9 +213,11 @@ struct LowLatencyLayout {
}; };
inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts, int quant_group_size=0) { int num_ranks, int num_experts,
bool enable_dispatch_ll_layered=false, int quant_group_size=0) {
auto num_bytes = auto num_bytes =
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size) LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered, quant_group_size)
.total_bytes; .total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
NUM_BUFFER_ALIGNMENT_BYTES; NUM_BUFFER_ALIGNMENT_BYTES;
......
...@@ -13,11 +13,14 @@ ...@@ -13,11 +13,14 @@
namespace deep_ep { namespace deep_ep {
Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool enable_shrink) bool low_latency_mode, bool explicitly_destroy, bool enable_shrink,
bool enable_dispatch_ll_layered, bool enable_combine_overlap)
: rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes), : rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes),
num_rdma_bytes(num_rdma_bytes), low_latency_mode(low_latency_mode), num_rdma_bytes(num_rdma_bytes), low_latency_mode(low_latency_mode),
explicitly_destroy(explicitly_destroy), explicitly_destroy(explicitly_destroy),
enable_shrink(enable_shrink), enable_shrink(enable_shrink),
enable_dispatch_ll_layered(enable_dispatch_ll_layered),
enable_combine_overlap(enable_combine_overlap),
comm_stream(at::hip::getStreamFromPoolMasqueradingAsCUDA(true)) { comm_stream(at::hip::getStreamFromPoolMasqueradingAsCUDA(true)) {
// Metadata memory // Metadata memory
int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
...@@ -25,6 +28,8 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ ...@@ -25,6 +28,8 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int *); int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int *);
EP_HOST_ASSERT(enable_shrink == false); EP_HOST_ASSERT(enable_shrink == false);
if (enable_dispatch_ll_layered)
EP_HOST_ASSERT(enable_combine_overlap == true);
// Common checks // Common checks
EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and
...@@ -1274,7 +1279,8 @@ Buffer::internode_combine( ...@@ -1274,7 +1279,8 @@ Buffer::internode_combine(
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts, int quant_group_size) { void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts, int quant_group_size) {
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size); auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered, quant_group_size);
auto clean_meta_0 = layout.buffers[0].clean_meta(); auto clean_meta_0 = layout.buffers[0].clean_meta();
auto clean_meta_1 = layout.buffers[1].clean_meta(); auto clean_meta_1 = layout.buffers[1].clean_meta();
...@@ -1311,7 +1317,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1311,7 +1317,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto num_local_experts = num_experts / num_ranks; auto num_local_experts = num_experts / num_ranks;
// Buffer control // Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size); LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, enable_dispatch_ll_layered, quant_group_size);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx]; auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
...@@ -1336,7 +1342,16 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1336,7 +1342,16 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
} }
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype)); auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
torch::Dtype dtype = torch::kInt32;
if(enable_dispatch_ll_layered or enable_combine_overlap){
dtype = torch::kInt64;
}
auto packed_recv_src_info = torch::empty(
{num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(dtype).device(torch::kCUDA)
);
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
...@@ -1371,10 +1386,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1371,10 +1386,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
} }
if(!enable_dispatch_ll_layered){
// Kernel launch // Kernel launch
auto next_clean_meta = next_buffer.clean_meta(); auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) { auto launcher = [=](int phases) {
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, internode_ll::dispatch(
packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(), packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(), packed_recv_count.data_ptr<int>(),
global_atomic_counter.data_ptr<int>(), global_atomic_counter.data_ptr<int>(),
...@@ -1406,17 +1423,72 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1406,17 +1423,72 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Return values // Return values
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook}; return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
} else {
// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
internode_ll::dispatch_ll_layered(
!enable_dispatch_ll_layered,
packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
packed_recv_src_info.data_ptr<int64_t>(), packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
global_atomic_counter.data_ptr<int>(),
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
quant_type, quant_group_size, fp8_round_scale,
workspace, num_device_sms, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// Return values
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
}
} }
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& packed_recv_count,
const std::optional<torch::Tensor>& comp_signal,
int block_m, int threshold, int num_sms,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats, const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) { const std::optional<torch::Tensor>& out) {
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
// combine overlap checks
EP_HOST_ASSERT((!enable_combine_overlap || return_recv_hook) and "Overlap mode requires return_recv_hook=True"); // 启用 overlap 时, 必须 hook = True
EP_HOST_ASSERT((!enable_combine_overlap || packed_recv_count.has_value()) && "Overlap mode requires packed_recv_count has value");
EP_HOST_ASSERT((!enable_combine_overlap || comp_signal.has_value()) && "Overlap mode requires comp_signal has value");
EP_HOST_ASSERT((!enable_combine_overlap || block_m != -1) && "Overlap mode requires block_m != -1");
EP_HOST_ASSERT((!enable_combine_overlap || threshold != -1) && "Overlap mode requires threshold != -1");
EP_HOST_ASSERT((!enable_combine_overlap || num_sms != -1) && "Overlap mode requires num_sms != -1");
if (comp_signal.has_value()) {
EP_HOST_ASSERT(comp_signal->dim() == 1 and comp_signal->is_contiguous());
EP_HOST_ASSERT(comp_signal->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(comp_signal->size(0) == num_experts / num_ranks * ((num_ranks * num_max_dispatch_tokens_per_rank + 63) / 64));
}
// Tensor checks // Tensor checks
EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
...@@ -1430,7 +1502,12 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1430,7 +1502,12 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32); EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous()); EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous());
if (!enable_dispatch_ll_layered && !enable_combine_overlap) {
EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0)); EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0));
} else {
EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt64 and x.size(0) == src_info.size(0));
}
EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous()); EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
...@@ -1446,7 +1523,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1446,7 +1523,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
auto num_combined_tokens = static_cast<int>(topk_weights.size(0)); auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Buffer control // Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, enable_dispatch_ll_layered);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx]; auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
...@@ -1472,8 +1549,10 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1472,8 +1549,10 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
// Kernel launch // Kernel launch
auto next_clean_meta = next_buffer.clean_meta(); auto next_clean_meta = next_buffer.clean_meta();
if(!enable_combine_overlap) {
auto launcher = [=](int phases) { auto launcher = [=](int phases) {
internode_ll::combine(combined_x.data_ptr(), internode_ll::combine(
combined_x.data_ptr(),
buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_send_buffer, buffer.combine_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(), x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
...@@ -1506,10 +1585,55 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1506,10 +1585,55 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
// Return values // Return values
return {combined_x, event, recv_hook}; return {combined_x, event, recv_hook};
} else {
auto launcher = [=](int phases) {
internode_ll::combine_sbo(
combined_x.data_ptr(),
buffer.combine_rdma_recv_data_buffer,
buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
src_info.data_ptr<int64_t>(), layout_range.data_ptr<int64_t>(),
/* ll_layered 新增参数 */
!enable_dispatch_ll_layered,
/* overlap 新增参数 */
packed_recv_count.has_value() ? packed_recv_count->data_ptr<int>() : nullptr,
comp_signal.has_value() ? comp_signal->data_ptr<int>() : nullptr,
block_m, threshold, num_sms,
/* 辅助tensor */
global_atomic_counter.data_ptr<int>(),
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
workspace, num_device_sms, launch_stream,
phases, zero_copy);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// Return values
return {combined_x, event, recv_hook};
}
} }
torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, enable_dispatch_ll_layered);
auto buffer = layout.buffers[low_latency_buffer_idx]; auto buffer = layout.buffers[low_latency_buffer_idx];
auto dtype = torch::kBFloat16; auto dtype = torch::kBFloat16;
auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16)); auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
...@@ -1540,7 +1664,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -1540,7 +1664,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait); .def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait);
pybind11::class_<deep_ep::Buffer>(m, "Buffer") pybind11::class_<deep_ep::Buffer>(m, "Buffer")
.def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool>()) .def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool, bool, bool>())
.def("is_available", &deep_ep::Buffer::is_available) .def("is_available", &deep_ep::Buffer::is_available)
.def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks)
.def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank)
......
...@@ -35,6 +35,8 @@ private: ...@@ -35,6 +35,8 @@ private:
// Shrink mode buffer // Shrink mode buffer
bool enable_shrink = false; bool enable_shrink = false;
bool enable_dispatch_ll_layered = false;
bool enable_combine_overlap = false;
int* mask_buffer_ptr = nullptr; int* mask_buffer_ptr = nullptr;
int* sync_buffer_ptr = nullptr; int* sync_buffer_ptr = nullptr;
...@@ -77,7 +79,8 @@ private: ...@@ -77,7 +79,8 @@ private:
public: public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool enable_shrink); bool low_latency_mode, bool explicitly_destroy, bool enable_shrink,
bool enable_dispatch_ll_layered, bool enable_combine_overlap);
~Buffer() noexcept(false); ~Buffer() noexcept(false);
...@@ -183,6 +186,9 @@ public: ...@@ -183,6 +186,9 @@ public:
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& packed_recv_count,
const std::optional<torch::Tensor>& comp_signal,
int block_m, int threshold, int num_sms,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats, const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool use_logfmt,
......
...@@ -150,6 +150,20 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -150,6 +150,20 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int quant_type, int group_size, bool fp8_round_scale, int quant_type, int group_size, bool fp8_round_scale,
void* workspace, int num_device_sms, hipStream_t stream, int phases); void* workspace, int num_device_sms, hipStream_t stream, int phases);
// Layered ("ll_layered") variant of the low-latency dispatch launcher.
//
// Buffer::low_latency_dispatch calls this instead of internode_ll::dispatch when
// enable_dispatch_ll_layered is set. The parameter list mirrors dispatch(), with
// these differences visible at the call site:
//   - dispatch_ll_dispatch_opt: extra leading flag. The caller passes
//     `!enable_dispatch_ll_layered`, which is always false on that path.
//     NOTE(review): confirm the intended meaning/polarity of this flag.
//   - packed_recv_src_info: int64_t* here (the caller allocates the tensor as
//     kInt64 in layered/overlap mode), whereas dispatch() receives an int*.
//
// Other arguments as supplied by the caller:
//   - rdma_recv_x / rdma_recv_count / rdma_x: the current layout buffer's
//     dispatch_rdma_recv_data_buffer, dispatch_rdma_recv_count_buffer and
//     dispatch_rdma_send_buffer.
//   - next_clean / num_next_clean_int: the pair returned by clean_meta() of the
//     next layout buffer.
//   - phases: combination of LOW_LATENCY_SEND_PHASE and LOW_LATENCY_RECV_PHASE.
//   - stream: the stream the caller launches on.
void dispatch_ll_layered(bool dispatch_ll_dispatch_opt,
void* packed_recv_x, void* packed_recv_x_scales,
int64_t* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int quant_type, int quant_group_size, bool fp8_round_scale,
void* workspace, int num_device_sms,
hipStream_t stream, int phases);
void combine(void* combined_x, void combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const int64_t* topk_idx, const float* topk_weights,
...@@ -163,6 +177,24 @@ void combine(void* combined_x, ...@@ -163,6 +177,24 @@ void combine(void* combined_x,
void* workspace, int num_device_sms, hipStream_t stream, void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy); int phases, bool zero_copy);
// Overlap-capable variant of the low-latency combine launcher.
//
// Buffer::low_latency_combine calls this instead of internode_ll::combine when
// enable_combine_overlap is set. Notes derived from that call site:
//   - src_info is int64_t here; the host asserts kInt64 whenever the layered
//     dispatch or combine-overlap mode is enabled.
//   - disable_ll_layered receives `!enable_dispatch_ll_layered`.
//   - packed_recv_count / comp_signal are forwarded as nullptr when the optional
//     tensor is absent, but the host asserts both are present in overlap mode.
//   - comp_signal is checked on the host to be a contiguous 1-D int32 tensor of
//     size (num_experts / num_ranks) *
//     ceil(num_ranks * num_max_dispatch_tokens_per_rank / 64).
//   - block_m / threshold / num_sms must not be -1 in overlap mode (host assert).
//     NOTE(review): their exact kernel-side meaning is not visible here.
//   - combine_wait_recv_cost_stats may be nullptr.
//   - next_clean / num_next_clean_int: the pair returned by clean_meta() of the
//     next layout buffer.
//   - phases: combination of LOW_LATENCY_SEND_PHASE and LOW_LATENCY_RECV_PHASE.
void combine_sbo(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int64_t* src_info, const int64_t* layout_range,
// Control parameters added for overlap (disable_ll_layered is the ll_layered switch)
bool disable_ll_layered,
int* packed_recv_count, int* comp_signal,
int block_m, int threshold, int num_sms,
// Synchronization and statistics parameters
int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int,
// Dimension and configuration parameters
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
// System resources and execution parameters
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy);
} // namespace internode_ll } // namespace internode_ll
} // namespace deep_ep } // namespace deep_ep
This diff is collapsed.
...@@ -40,6 +40,8 @@ class Buffer: ...@@ -40,6 +40,8 @@ class Buffer:
allow_mnnvl: bool = False, allow_mnnvl: bool = False,
explicitly_destroy: bool = False, explicitly_destroy: bool = False,
enable_shrink: bool = False, enable_shrink: bool = False,
enable_dispatch_ll_layered: bool = False,
enable_combine_overlap: bool = False,
) -> None: ) -> None:
""" """
Initialize the communication buffer. Initialize the communication buffer.
...@@ -60,6 +62,8 @@ class Buffer: ...@@ -60,6 +62,8 @@ class Buffer:
otherwise, the resources will be released by the destructor. otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang. Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically. enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically.
enable_dispatch_ll_layered: Enable low-latency mode with hierarchical dispatch operators.
enable_combine_overlap: whether to overlap the DeepGEMM down GEMM with the combine send.
""" """
check_nvlink_connections(group) check_nvlink_connections(group)
...@@ -72,6 +76,10 @@ class Buffer: ...@@ -72,6 +76,10 @@ class Buffer:
self.low_latency_mode = low_latency_mode self.low_latency_mode = low_latency_mode
self.explicitly_destroy = explicitly_destroy self.explicitly_destroy = explicitly_destroy
self.enable_shrink = enable_shrink self.enable_shrink = enable_shrink
if enable_dispatch_ll_layered and enable_shrink:
    # The layered low-latency dispatch path does not support shrink mode, so
    # shrink is forced off here. `self.enable_shrink` was already assigned from
    # the argument just above, so reset the attribute together with the local
    # (which is what gets forwarded to the C++ runtime); otherwise the Python
    # object would keep reporting shrink as enabled while the runtime has it off.
    print("DeepEP [ERROR] not support shrink, disable it", flush=True)
    enable_shrink = False
    self.enable_shrink = False
self.runtime = deep_ep_cpp.Buffer( self.runtime = deep_ep_cpp.Buffer(
self.rank, self.rank,
self.group_size, self.group_size,
...@@ -79,7 +87,9 @@ class Buffer: ...@@ -79,7 +87,9 @@ class Buffer:
num_rdma_bytes, num_rdma_bytes,
low_latency_mode, low_latency_mode,
explicitly_destroy, explicitly_destroy,
enable_shrink enable_shrink,
enable_dispatch_ll_layered,
enable_combine_overlap
) )
# Synchronize device IDs # Synchronize device IDs
...@@ -212,7 +222,8 @@ class Buffer: ...@@ -212,7 +222,8 @@ class Buffer:
@staticmethod @staticmethod
def get_low_latency_rdma_size_hint( def get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int, quant_group_size: int = 0 num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int,
enable_dispatch_ll_layered: bool = False, quant_group_size: int = 0
) -> int: ) -> int:
""" """
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16. Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
...@@ -228,7 +239,8 @@ class Buffer: ...@@ -228,7 +239,8 @@ class Buffer:
size: the RDMA buffer size recommended. size: the RDMA buffer size recommended.
""" """
return deep_ep_cpp.get_low_latency_rdma_size_hint( return deep_ep_cpp.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered, quant_group_size
) )
def get_comm_stream(self) -> torch.Stream: def get_comm_stream(self) -> torch.Stream:
...@@ -921,9 +933,11 @@ class Buffer: ...@@ -921,9 +933,11 @@ class Buffer:
recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x
return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook
# noinspection PyTypeChecker def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: tuple,
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, # combine sbo params
handle: tuple, use_logfmt: bool = False, packed_recv_count: torch.Tensor = None, comp_signal: torch.Tensor = None,
block_m: int = -1, threshold: int = -1, num_sms: int = -1,
use_logfmt: bool = False,
zero_copy: bool = False, async_finish: bool = False, zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None, return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \ combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
...@@ -945,13 +959,13 @@ class Buffer: ...@@ -945,13 +959,13 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor. tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function. handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`. with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival. If you do not set this flag, the kernel will ensure the data's arrival.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics, combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`. which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
...@@ -964,6 +978,7 @@ class Buffer: ...@@ -964,6 +978,7 @@ class Buffer:
""" """
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
packed_recv_count, comp_signal, block_m, threshold, num_sms,
combine_wait_recv_cost_stats, combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts, num_max_dispatch_tokens_per_rank, num_experts,
use_logfmt, zero_copy, async_finish, return_recv_hook, out) use_logfmt, zero_copy, async_finish, return_recv_hook, out)
......
#!/bin/bash
# rocSHMEM # rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288 export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48 export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9 export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240 export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# duSHMEM # # duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH # export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1 # export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export NVSHMEM_SYMMETRIC_SIZE=10737418240 # export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common # common
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
...@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../ ...@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../
# test # test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
#!/bin/bash
# rocSHMEM # rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288 export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48 export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9 export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240 export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# duSHMEM # # duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH # export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1 # export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export NVSHMEM_SYMMETRIC_SIZE=10737418240 # export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common # common
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
...@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../ ...@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../
# test # test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
...@@ -34,6 +34,10 @@ def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], bu ...@@ -34,6 +34,10 @@ def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], bu
assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks
def ceil_div(a, b):
    """Divide `a` by `b`, rounding the quotient up.

    Implemented as a floor division of `a + b - 1` by `b`, which equals the
    mathematical ceiling of `a / b` for non-negative integer `a` and positive
    integer `b`.
    """
    padded_numerator = a + b - 1
    return padded_numerator // b
def test_main(num_tokens: int, def test_main(num_tokens: int,
hidden: int, hidden: int,
num_experts: int, num_experts: int,
...@@ -42,11 +46,16 @@ def test_main(num_tokens: int, ...@@ -42,11 +46,16 @@ def test_main(num_tokens: int,
num_ranks: int, num_ranks: int,
group: dist.ProcessGroup, group: dist.ProcessGroup,
buffer: deep_ep.Buffer, buffer: deep_ep.Buffer,
enable_dispatch_ll_layered: bool = False,
enable_combine_overlap: bool = False,
use_logfmt: bool = False, use_logfmt: bool = False,
seed: int = 0): seed: int = 0):
torch.manual_seed(seed + rank) torch.manual_seed(seed + rank)
random.seed(seed + rank) random.seed(seed + rank)
print(f"enable_dispatch_ll_layered={enable_dispatch_ll_layered}, enable_combine_overlap={enable_combine_overlap}, use_logfmt={use_logfmt}")
assert not (use_logfmt and (enable_dispatch_ll_layered or enable_combine_overlap)), \
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
assert num_experts % num_ranks == 0 assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks num_local_experts = num_experts // num_ranks
...@@ -84,10 +93,13 @@ def test_main(num_tokens: int, ...@@ -84,10 +93,13 @@ def test_main(num_tokens: int,
hash_value, num_times = 0, 0 hash_value, num_times = 0, 0
for x_i, current_x in enumerate(x_list): for x_i, current_x in enumerate(x_list):
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2 if enable_combine_overlap and (not return_recv_hook): # return_recv_hook 为False 时,不能启用 overlap
continue
for quant_type in (0, 1, 2, 3,): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0 dispatch_use_quant = quant_type > 0
for fp8_round_scale in (False, True) if quant_type != 3 else (True, ): for fp8_round_scale in (False, True) if quant_type != 3 else (True,):
for quant_group_size in (0, 128,) if quant_type >= 2 else (0, ): for quant_group_size in (0, 128,) if quant_type >= 2 else (0,):
if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0): if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0):
continue continue
...@@ -131,7 +143,12 @@ def test_main(num_tokens: int, ...@@ -131,7 +143,12 @@ def test_main(num_tokens: int,
recv_x = recv_x[:num_valid_tokens] recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1) recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1) recv_x_amax = recv_x[:, :-128].amax(dim=-1)
if (enable_dispatch_ll_layered or enable_combine_overlap):
recv_src_info = recv_src_info[:num_valid_tokens] & int_mask # 掩掉多余的信息
else:
recv_src_info = recv_src_info[:num_valid_tokens] recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x_amax) assert torch.equal(recv_x_amin, recv_x_amax)
if dispatch_use_quant: if dispatch_use_quant:
...@@ -148,6 +165,7 @@ def test_main(num_tokens: int, ...@@ -148,6 +165,7 @@ def test_main(num_tokens: int,
if not fp8_round_scale: if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0 assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_quant: if dispatch_use_quant:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
...@@ -155,10 +173,32 @@ def test_main(num_tokens: int, ...@@ -155,10 +173,32 @@ def test_main(num_tokens: int,
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
# Check combine correctness # Check combine correctness
for zero_copy in (False, ) if use_logfmt else (False, True, ): for zero_copy in (False,) if use_logfmt else (False, True,):
if zero_copy: if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
if enable_combine_overlap:
block_m, threshold, num_sms = 64, 10, 3
total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m) # 每个本地专家 总的信号数??
comp_signal = torch.zeros(num_local_experts * total_num_per_expert, dtype=torch.int32, device='cuda')
for i in range(num_local_experts):
vaild_num = ceil_div(packed_recv_count[i], block_m)
comp_signal[i * total_num_per_expert:i * total_num_per_expert + vaild_num] = threshold
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
packed_recv_count=packed_recv_count,
comp_signal=comp_signal,
block_m=block_m,
threshold=threshold,
num_sms=num_sms,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
else:
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx, topk_idx,
topk_weights, topk_weights,
...@@ -168,6 +208,7 @@ def test_main(num_tokens: int, ...@@ -168,6 +208,7 @@ def test_main(num_tokens: int,
zero_copy=zero_copy, zero_copy=zero_copy,
return_recv_hook=return_recv_hook, return_recv_hook=return_recv_hook,
out=out) out=out)
hook() if return_recv_hook else event.current_stream_wait() hook() if return_recv_hook else event.current_stream_wait()
if do_check: if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
...@@ -180,6 +221,10 @@ def test_main(num_tokens: int, ...@@ -180,6 +221,10 @@ def test_main(num_tokens: int,
print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ", print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass") f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")
print("deep_ep 全部正确性测试完成")
if enable_dispatch_ll_layered or enable_combine_overlap:
return hash_value
# noinspection PyShadowingNames # noinspection PyShadowingNames
def large_gemm_with_hook(hook): def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float) mat_0 = torch.randn((8192, 8192), dtype=torch.float)
...@@ -242,7 +287,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -242,7 +287,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_tokens, hidden = args.num_tokens, args.hidden num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts num_topk, num_experts = args.num_topk, args.num_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) enable_dispatch_ll_layered = args.enable_dispatch_ll_layered
enable_combine_overlap = args.enable_combine_overlap
if enable_dispatch_ll_layered:
enable_combine_overlap = True
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered=enable_dispatch_ll_layered)
if local_rank == 0: if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True) print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group, buffer = deep_ep.Buffer(group,
...@@ -251,7 +302,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -251,7 +302,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=num_experts // num_ranks, num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink, allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
explicitly_destroy=True, explicitly_destroy=True,
allow_mnnvl=args.allow_mnnvl) allow_mnnvl=args.allow_mnnvl,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap
)
print("deep_ep 初始化完成")
test_main(num_tokens, test_main(num_tokens,
hidden, hidden,
num_experts, num_experts,
...@@ -261,6 +316,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -261,6 +316,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group, group,
buffer, buffer,
use_logfmt=args.use_logfmt, use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=1) seed=1)
do_pressure_test = args.pressure_test do_pressure_test = args.pressure_test
...@@ -276,6 +333,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -276,6 +333,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group, group,
buffer, buffer,
use_logfmt=args.use_logfmt, use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed) seed=seed)
for _ in range(20): for _ in range(20):
assert test_main(num_tokens, assert test_main(num_tokens,
...@@ -287,6 +346,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -287,6 +346,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group, group,
buffer, buffer,
use_logfmt=args.use_logfmt, use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed) == ref_hash, f'Error: seed={seed}' seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group # Destroy the buffer runtime and communication group
...@@ -309,6 +370,10 @@ if __name__ == '__main__': ...@@ -309,6 +370,10 @@ if __name__ == '__main__':
parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test') parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode') parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine') parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
# 新版 sbo 需要的
parser.add_argument('--enable-dispatch-ll-layered', action='store_true', help='Enable low-latency layered dispatch optimization')
parser.add_argument("--enable-combine-overlap", action='store_true', help='Enable GEMM-compute/communication overlap in the combine phase')
args = parser.parse_args() args = parser.parse_args()
num_processes = args.num_processes num_processes = args.num_processes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment