Commit da6da7c3 authored by lishen's avatar lishen
Browse files

low-latency添加dispatch分层优化和combine gemm overlap

parent ea76f44e
...@@ -135,9 +135,11 @@ struct LowLatencyLayout { ...@@ -135,9 +135,11 @@ struct LowLatencyLayout {
} }
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts, int quant_group_size=0) { int num_ranks, int num_experts, bool enable_dispatch_ll_layered=false, int quant_group_size=0) {
const int num_scales = hidden / QUANTIZATION_GROUPSIZE; const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
const int num_nodes = num_ranks / NUM_MAX_NVL_PEERS; // 计算结点数
// Dispatch and combine layout: // Dispatch and combine layout:
// - 2 symmetric odd/even send buffer // - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even receive buffers // - 2 symmetric odd/even receive buffers
...@@ -152,7 +154,9 @@ struct LowLatencyLayout { ...@@ -152,7 +154,9 @@ struct LowLatencyLayout {
(quant_group_size == 0 ? 4 : num_scales) * sizeof(float)); // 应该是1,但是代码中为了满足int4对齐 (quant_group_size == 0 ? 4 : num_scales) * sizeof(float)); // 应该是1,但是代码中为了满足int4对齐
// 与internode_ll::combine 中的 num_bytes_per_slot 相等 // 与internode_ll::combine 中的 num_bytes_per_slot 相等
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) + num_scales * sizeof(__hip_bfloat162); size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) +
(enable_dispatch_ll_layered ? 0 : // 即enable_combine_overlap==true,执行函数combine_sbo
num_scales * sizeof(__hip_bfloat162));
// Send buffer // Send buffer
size_t dispatch_send_buffer_bytes = size_t dispatch_send_buffer_bytes =
...@@ -176,6 +180,10 @@ struct LowLatencyLayout { ...@@ -176,6 +180,10 @@ struct LowLatencyLayout {
// Symmetric signaling buffers // Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t); size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t);
if (enable_dispatch_ll_layered) {
dispatch_recv_count_buffer_bytes +=
NUM_MAX_NVL_PEERS * num_nodes * num_max_dispatch_tokens_per_rank * sizeof(int) + NUM_MAX_NVL_PEERS * sizeof(int);
}
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128); size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
...@@ -205,9 +213,11 @@ struct LowLatencyLayout { ...@@ -205,9 +213,11 @@ struct LowLatencyLayout {
}; };
inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts, int quant_group_size=0) { int num_ranks, int num_experts,
bool enable_dispatch_ll_layered=false, int quant_group_size=0) {
auto num_bytes = auto num_bytes =
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size) LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered, quant_group_size)
.total_bytes; .total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
NUM_BUFFER_ALIGNMENT_BYTES; NUM_BUFFER_ALIGNMENT_BYTES;
......
This diff is collapsed.
...@@ -35,6 +35,8 @@ private: ...@@ -35,6 +35,8 @@ private:
// Shrink mode buffer // Shrink mode buffer
bool enable_shrink = false; bool enable_shrink = false;
bool enable_dispatch_ll_layered = false;
bool enable_combine_overlap = false;
int* mask_buffer_ptr = nullptr; int* mask_buffer_ptr = nullptr;
int* sync_buffer_ptr = nullptr; int* sync_buffer_ptr = nullptr;
...@@ -77,7 +79,8 @@ private: ...@@ -77,7 +79,8 @@ private:
public: public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool enable_shrink); bool low_latency_mode, bool explicitly_destroy, bool enable_shrink,
bool enable_dispatch_ll_layered, bool enable_combine_overlap);
~Buffer() noexcept(false); ~Buffer() noexcept(false);
...@@ -183,6 +186,9 @@ public: ...@@ -183,6 +186,9 @@ public:
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& packed_recv_count,
const std::optional<torch::Tensor>& comp_signal,
int block_m, int threshold, int num_sms,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats, const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool use_logfmt,
......
...@@ -150,6 +150,20 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -150,6 +150,20 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int quant_type, int group_size, bool fp8_round_scale, int quant_type, int group_size, bool fp8_round_scale,
void* workspace, int num_device_sms, hipStream_t stream, int phases); void* workspace, int num_device_sms, hipStream_t stream, int phases);
void dispatch_ll_layered(bool dispatch_ll_dispatch_opt,
void* packed_recv_x, void* packed_recv_x_scales,
int64_t* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int quant_type, int quant_group_size, bool fp8_round_scale,
void* workspace, int num_device_sms,
hipStream_t stream, int phases);
void combine(void* combined_x, void combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const int64_t* topk_idx, const float* topk_weights,
...@@ -163,6 +177,24 @@ void combine(void* combined_x, ...@@ -163,6 +177,24 @@ void combine(void* combined_x,
void* workspace, int num_device_sms, hipStream_t stream, void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy); int phases, bool zero_copy);
void combine_sbo(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int64_t* src_info, const int64_t* layout_range,
// Overlap 新增控制参数
bool disable_ll_layered,
int* packed_recv_count, int* comp_signal,
int block_m, int threshold, int num_sms,
// 同步与统计参数
int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int,
// 维度与配置参数
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
// 系统资源与执行参数
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy);
} // namespace internode_ll } // namespace internode_ll
} // namespace deep_ep } // namespace deep_ep
This diff is collapsed.
...@@ -40,6 +40,8 @@ class Buffer: ...@@ -40,6 +40,8 @@ class Buffer:
allow_mnnvl: bool = False, allow_mnnvl: bool = False,
explicitly_destroy: bool = False, explicitly_destroy: bool = False,
enable_shrink: bool = False, enable_shrink: bool = False,
enable_dispatch_ll_layered: bool = False,
enable_combine_overlap: bool = False,
) -> None: ) -> None:
""" """
Initialize the communication buffer. Initialize the communication buffer.
...@@ -60,6 +62,8 @@ class Buffer: ...@@ -60,6 +62,8 @@ class Buffer:
otherwise, the resources will be released by the destructor. otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang. Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically. enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically.
enable_dispatch_ll_layered: Enable low-latency mode with hierarchical dispatch operators.
enable_combine_overlap: deepgemm DOWN gemm overlop combine send
""" """
check_nvlink_connections(group) check_nvlink_connections(group)
...@@ -72,6 +76,10 @@ class Buffer: ...@@ -72,6 +76,10 @@ class Buffer:
self.low_latency_mode = low_latency_mode self.low_latency_mode = low_latency_mode
self.explicitly_destroy = explicitly_destroy self.explicitly_destroy = explicitly_destroy
self.enable_shrink = enable_shrink self.enable_shrink = enable_shrink
if enable_dispatch_ll_layered and enable_shrink: # Currently, the layered algorithm for ll dispatch has been optimized, so the shrink mode is no longer supported.
print("DeepEP [ERROR] not support shrink, disable it", flush=True)
enable_shrink = False
self.runtime = deep_ep_cpp.Buffer( self.runtime = deep_ep_cpp.Buffer(
self.rank, self.rank,
self.group_size, self.group_size,
...@@ -79,7 +87,9 @@ class Buffer: ...@@ -79,7 +87,9 @@ class Buffer:
num_rdma_bytes, num_rdma_bytes,
low_latency_mode, low_latency_mode,
explicitly_destroy, explicitly_destroy,
enable_shrink enable_shrink,
enable_dispatch_ll_layered,
enable_combine_overlap
) )
# Synchronize device IDs # Synchronize device IDs
...@@ -212,7 +222,8 @@ class Buffer: ...@@ -212,7 +222,8 @@ class Buffer:
@staticmethod @staticmethod
def get_low_latency_rdma_size_hint( def get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int, quant_group_size: int = 0 num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int,
enable_dispatch_ll_layered: bool = False, quant_group_size: int = 0
) -> int: ) -> int:
""" """
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16. Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
...@@ -228,7 +239,8 @@ class Buffer: ...@@ -228,7 +239,8 @@ class Buffer:
size: the RDMA buffer size recommended. size: the RDMA buffer size recommended.
""" """
return deep_ep_cpp.get_low_latency_rdma_size_hint( return deep_ep_cpp.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered, quant_group_size
) )
def get_comm_stream(self) -> torch.Stream: def get_comm_stream(self) -> torch.Stream:
...@@ -921,9 +933,11 @@ class Buffer: ...@@ -921,9 +933,11 @@ class Buffer:
recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x
return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook
# noinspection PyTypeChecker def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: tuple,
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, # combine sbo params
handle: tuple, use_logfmt: bool = False, packed_recv_count: torch.Tensor = None, comp_signal: torch.Tensor = None,
block_m: int = -1, threshold: int = -1, num_sms: int = -1,
use_logfmt: bool = False,
zero_copy: bool = False, async_finish: bool = False, zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None, return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \ combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
...@@ -945,13 +959,13 @@ class Buffer: ...@@ -945,13 +959,13 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor. tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function. handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`. with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival. If you not set this flag, the kernel will ensure the data's arrival.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics, combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`. which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
...@@ -964,6 +978,7 @@ class Buffer: ...@@ -964,6 +978,7 @@ class Buffer:
""" """
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
packed_recv_count, comp_signal, block_m, threshold, num_sms,
combine_wait_recv_cost_stats, combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts, num_max_dispatch_tokens_per_rank, num_experts,
use_logfmt, zero_copy, async_finish, return_recv_hook, out) use_logfmt, zero_copy, async_finish, return_recv_hook, out)
......
#!/bin/bash
# rocSHMEM # rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288 export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48 export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9 export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240 export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# duSHMEM # # duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH # export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1 # export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export NVSHMEM_SYMMETRIC_SIZE=10737418240 # export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common # common
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
...@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../ ...@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../
# test # test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
#!/bin/bash
# rocSHMEM # rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288 export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48 export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9 export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240 export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# duSHMEM # # duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH # export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1 # export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export NVSHMEM_SYMMETRIC_SIZE=10737418240 # export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common # common
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
...@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../ ...@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../
# test # test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
...@@ -34,6 +34,10 @@ def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], bu ...@@ -34,6 +34,10 @@ def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], bu
assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks
def ceil_div(a, b):
return (a + b - 1) // b
def test_main(num_tokens: int, def test_main(num_tokens: int,
hidden: int, hidden: int,
num_experts: int, num_experts: int,
...@@ -42,11 +46,16 @@ def test_main(num_tokens: int, ...@@ -42,11 +46,16 @@ def test_main(num_tokens: int,
num_ranks: int, num_ranks: int,
group: dist.ProcessGroup, group: dist.ProcessGroup,
buffer: deep_ep.Buffer, buffer: deep_ep.Buffer,
enable_dispatch_ll_layered: bool = False,
enable_combine_overlap: bool = False,
use_logfmt: bool = False, use_logfmt: bool = False,
seed: int = 0): seed: int = 0):
torch.manual_seed(seed + rank) torch.manual_seed(seed + rank)
random.seed(seed + rank) random.seed(seed + rank)
print(f"enable_dispatch_ll_layered={enable_dispatch_ll_layered}, enable_combine_overlap={enable_combine_overlap}, use_logfmt={use_logfmt}")
assert not (use_logfmt and (enable_dispatch_ll_layered or enable_combine_overlap)), \
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
assert num_experts % num_ranks == 0 assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks num_local_experts = num_experts // num_ranks
...@@ -84,10 +93,13 @@ def test_main(num_tokens: int, ...@@ -84,10 +93,13 @@ def test_main(num_tokens: int,
hash_value, num_times = 0, 0 hash_value, num_times = 0, 0
for x_i, current_x in enumerate(x_list): for x_i, current_x in enumerate(x_list):
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2 if enable_combine_overlap and (not return_recv_hook): # return_recv_hook 为False 时,不能启用 overlop
continue
for quant_type in (0, 1, 2, 3,): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0 dispatch_use_quant = quant_type > 0
for fp8_round_scale in (False, True) if quant_type != 3 else (True, ): for fp8_round_scale in (False, True) if quant_type != 3 else (True,):
for quant_group_size in (0, 128,) if quant_type >= 2 else (0, ): for quant_group_size in (0, 128,) if quant_type >= 2 else (0,):
if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0): if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0):
continue continue
...@@ -131,9 +143,14 @@ def test_main(num_tokens: int, ...@@ -131,9 +143,14 @@ def test_main(num_tokens: int,
recv_x = recv_x[:num_valid_tokens] recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1) recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1) recv_x_amax = recv_x[:, :-128].amax(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
if (enable_dispatch_ll_layered or enable_combine_overlap):
recv_src_info = recv_src_info[:num_valid_tokens] & int_mask # 掩掉多余的信息
else:
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x_amax) assert torch.equal(recv_x_amin, recv_x_amax)
if dispatch_use_quant: if dispatch_use_quant:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007 assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
...@@ -148,6 +165,7 @@ def test_main(num_tokens: int, ...@@ -148,6 +165,7 @@ def test_main(num_tokens: int,
if not fp8_round_scale: if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0 assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_quant: if dispatch_use_quant:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
...@@ -155,19 +173,42 @@ def test_main(num_tokens: int, ...@@ -155,19 +173,42 @@ def test_main(num_tokens: int,
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
# Check combine correctness # Check combine correctness
for zero_copy in (False, ) if use_logfmt else (False, True, ): for zero_copy in (False,) if use_logfmt else (False, True,):
if zero_copy: if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, if enable_combine_overlap:
topk_idx, block_m, threshold, num_sms = 64, 10, 3
topk_weights, total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m) # 每个本地专家 总的信号数??
handle, comp_signal = torch.zeros(num_local_experts * total_num_per_expert, dtype=torch.int32, device='cuda')
use_logfmt=use_logfmt,
async_finish=not return_recv_hook, for i in range(num_local_experts):
zero_copy=zero_copy, vaild_num = ceil_div(packed_recv_count[i], block_m)
return_recv_hook=return_recv_hook, comp_signal[i * total_num_per_expert:i * total_num_per_expert + vaild_num] = threshold
out=out) combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
packed_recv_count=packed_recv_count,
comp_signal=comp_signal,
block_m=block_m,
threshold=threshold,
num_sms=num_sms,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
else:
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
use_logfmt=use_logfmt,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
hook() if return_recv_hook else event.current_stream_wait() hook() if return_recv_hook else event.current_stream_wait()
if do_check: if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
...@@ -177,9 +218,13 @@ def test_main(num_tokens: int, ...@@ -177,9 +218,13 @@ def test_main(num_tokens: int,
hash_value ^= hash_tensor(combined_x) hash_value ^= hash_tensor(combined_x)
if rank == 0: if rank == 0:
print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ", print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass") f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")
print("deep_ep 全部正确性测试完成")
if enable_dispatch_ll_layered or enable_combine_overlap:
return hash_value
# noinspection PyShadowingNames # noinspection PyShadowingNames
def large_gemm_with_hook(hook): def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float) mat_0 = torch.randn((8192, 8192), dtype=torch.float)
...@@ -242,7 +287,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -242,7 +287,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_tokens, hidden = args.num_tokens, args.hidden num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts num_topk, num_experts = args.num_topk, args.num_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) enable_dispatch_ll_layered = args.enable_dispatch_ll_layered
enable_combine_overlap = args.enable_combine_overlap
if enable_dispatch_ll_layered:
enable_combine_overlap = True
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered=enable_dispatch_ll_layered)
if local_rank == 0: if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True) print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group, buffer = deep_ep.Buffer(group,
...@@ -251,7 +302,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -251,7 +302,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=num_experts // num_ranks, num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink, allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
explicitly_destroy=True, explicitly_destroy=True,
allow_mnnvl=args.allow_mnnvl) allow_mnnvl=args.allow_mnnvl,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap
)
print("deep_ep 初始化完成")
test_main(num_tokens, test_main(num_tokens,
hidden, hidden,
num_experts, num_experts,
...@@ -261,6 +316,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -261,6 +316,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group, group,
buffer, buffer,
use_logfmt=args.use_logfmt, use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=1) seed=1)
do_pressure_test = args.pressure_test do_pressure_test = args.pressure_test
...@@ -276,6 +333,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -276,6 +333,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group, group,
buffer, buffer,
use_logfmt=args.use_logfmt, use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed) seed=seed)
for _ in range(20): for _ in range(20):
assert test_main(num_tokens, assert test_main(num_tokens,
...@@ -287,6 +346,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -287,6 +346,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group, group,
buffer, buffer,
use_logfmt=args.use_logfmt, use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed) == ref_hash, f'Error: seed={seed}' seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group # Destroy the buffer runtime and communication group
...@@ -309,6 +370,10 @@ if __name__ == '__main__': ...@@ -309,6 +370,10 @@ if __name__ == '__main__':
parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test') parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode') parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine') parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
# 新版 sbo 需要的
parser.add_argument('--enable-dispatch-ll-layered', action='store_true', help='Enable low-latency layered dispatch optimization')
parser.add_argument("--enable-combine-overlap", action='store_true', help='Enable GEMM-compute/communication overlap in the combine phase')
args = parser.parse_args() args = parser.parse_args()
num_processes = args.num_processes num_processes = args.num_processes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment