Commit da6da7c3 authored by lishen's avatar lishen
Browse files

low-latency添加dispatch分层优化和combine gemm overlap

parent ea76f44e
......@@ -135,9 +135,11 @@ struct LowLatencyLayout {
}
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts, int quant_group_size=0) {
int num_ranks, int num_experts, bool enable_dispatch_ll_layered=false, int quant_group_size=0) {
const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
const int num_nodes = num_ranks / NUM_MAX_NVL_PEERS; // 计算结点数
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even receive buffers
......@@ -152,7 +154,9 @@ struct LowLatencyLayout {
(quant_group_size == 0 ? 4 : num_scales) * sizeof(float)); // 应该是1,但是代码中为了满足int4对齐
// 与internode_ll::combine 中的 num_bytes_per_slot 相等
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) + num_scales * sizeof(__hip_bfloat162);
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) +
(enable_dispatch_ll_layered ? 0 : // 即enable_combine_overlap==true,执行函数combine_sbo
num_scales * sizeof(__hip_bfloat162));
// Send buffer
size_t dispatch_send_buffer_bytes =
......@@ -176,6 +180,10 @@ struct LowLatencyLayout {
// Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t);
if (enable_dispatch_ll_layered) {
dispatch_recv_count_buffer_bytes +=
NUM_MAX_NVL_PEERS * num_nodes * num_max_dispatch_tokens_per_rank * sizeof(int) + NUM_MAX_NVL_PEERS * sizeof(int);
}
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
......@@ -205,9 +213,11 @@ struct LowLatencyLayout {
};
inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts, int quant_group_size=0) {
int num_ranks, int num_experts,
bool enable_dispatch_ll_layered=false, int quant_group_size=0) {
auto num_bytes =
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size)
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered, quant_group_size)
.total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
NUM_BUFFER_ALIGNMENT_BYTES;
......
This diff is collapsed.
......@@ -35,6 +35,8 @@ private:
// Shrink mode buffer
bool enable_shrink = false;
bool enable_dispatch_ll_layered = false;
bool enable_combine_overlap = false;
int* mask_buffer_ptr = nullptr;
int* sync_buffer_ptr = nullptr;
......@@ -77,7 +79,8 @@ private:
public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool enable_shrink);
bool low_latency_mode, bool explicitly_destroy, bool enable_shrink,
bool enable_dispatch_ll_layered, bool enable_combine_overlap);
~Buffer() noexcept(false);
......@@ -183,6 +186,9 @@ public:
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& packed_recv_count,
const std::optional<torch::Tensor>& comp_signal,
int block_m, int threshold, int num_sms,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt,
......
......@@ -150,6 +150,20 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int quant_type, int group_size, bool fp8_round_scale,
void* workspace, int num_device_sms, hipStream_t stream, int phases);
// Declaration of the layered ("ll_layered") low-latency dispatch entry point in
// namespace internode_ll. Only the declaration is visible here, so the notes
// below are reviewer assumptions drawn from the parameter names.
// NOTE(review): `dispatch_ll_dispatch_opt` presumably toggles an optional
//   optimized dispatch path - confirm against the definition.
// NOTE(review): the `packed_recv_*` pointers look like outputs and `x` /
//   `topk_idx` like read-only inputs (they are the only const pointers) -
//   confirm against the definition.
// NOTE(review): `stream` is presumably the HIP stream the kernels are launched
//   on and `phases` a send/receive phase selector - confirm.
void dispatch_ll_layered(bool dispatch_ll_dispatch_opt,
void* packed_recv_x, void* packed_recv_x_scales,
int64_t* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int quant_type, int quant_group_size, bool fp8_round_scale,
void* workspace, int num_device_sms,
hipStream_t stream, int phases);
void combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
......@@ -163,6 +177,24 @@ void combine(void* combined_x,
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy);
// Declaration of the combine entry point used for the GEMM/combine overlap
// path ("sbo"). Only the declaration is visible here; the parameter grouping
// below follows the original inline comments.
// NOTE(review): `comp_signal`, `block_m` and `threshold` presumably describe
//   per-block completion signals written by the producing GEMM - confirm
//   against the definition.
void combine_sbo(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int64_t* src_info, const int64_t* layout_range,
// Control parameters newly added for overlap
bool disable_ll_layered,
int* packed_recv_count, int* comp_signal,
int block_m, int threshold, int num_sms,
// Synchronization and statistics parameters
int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int,
// Dimension and configuration parameters
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
// System resource and execution parameters
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy);
} // namespace internode_ll
} // namespace deep_ep
This diff is collapsed.
......@@ -40,6 +40,8 @@ class Buffer:
allow_mnnvl: bool = False,
explicitly_destroy: bool = False,
enable_shrink: bool = False,
enable_dispatch_ll_layered: bool = False,
enable_combine_overlap: bool = False,
) -> None:
"""
Initialize the communication buffer.
......@@ -60,6 +62,8 @@ class Buffer:
otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically.
enable_dispatch_ll_layered: Enable low-latency mode with hierarchical dispatch operators.
enable_combine_overlap: whether to overlap the combine send with the DeepGEMM DOWN GEMM computation.
"""
check_nvlink_connections(group)
......@@ -72,6 +76,10 @@ class Buffer:
self.low_latency_mode = low_latency_mode
self.explicitly_destroy = explicitly_destroy
# The layered low-latency dispatch path does not support shrink mode, so shrink
# is forced off when both flags are requested.
# Resolve the effective value BEFORE storing it on the instance: previously
# `self.enable_shrink` was assigned first and kept the stale `True` while the
# runtime constructed below received `False`.
if enable_dispatch_ll_layered and enable_shrink:
    print("DeepEP [ERROR] not support shrink, disable it", flush=True)
    enable_shrink = False
self.enable_shrink = enable_shrink
self.runtime = deep_ep_cpp.Buffer(
self.rank,
self.group_size,
......@@ -79,7 +87,9 @@ class Buffer:
num_rdma_bytes,
low_latency_mode,
explicitly_destroy,
enable_shrink
enable_shrink,
enable_dispatch_ll_layered,
enable_combine_overlap
)
# Synchronize device IDs
......@@ -212,7 +222,8 @@ class Buffer:
@staticmethod
def get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int, quant_group_size: int = 0
num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int,
enable_dispatch_ll_layered: bool = False, quant_group_size: int = 0
) -> int:
"""
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
......@@ -228,7 +239,8 @@ class Buffer:
size: the RDMA buffer size recommended.
"""
return deep_ep_cpp.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size
num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered, quant_group_size
)
def get_comm_stream(self) -> torch.Stream:
......@@ -921,9 +933,11 @@ class Buffer:
recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x
return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook
# noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, use_logfmt: bool = False,
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: tuple,
# combine sbo params
packed_recv_count: torch.Tensor = None, comp_signal: torch.Tensor = None,
block_m: int = -1, threshold: int = -1, num_sms: int = -1,
use_logfmt: bool = False,
zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
......@@ -945,13 +959,13 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
......@@ -964,6 +978,7 @@ class Buffer:
"""
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
packed_recv_count, comp_signal, block_m, threshold, num_sms,
combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts,
use_logfmt, zero_copy, async_finish, return_recv_hook, out)
......
#!/bin/bash
# rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240
export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export NVSHMEM_SYMMETRIC_SIZE=10737418240
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
# export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
# export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
......@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../
# test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
#!/bin/bash
# rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240
export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export NVSHMEM_SYMMETRIC_SIZE=10737418240
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
# export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
# export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
......@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../
# test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
......@@ -34,6 +34,10 @@ def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], bu
assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks
def ceil_div(a, b):
    """Return ``(a + b - 1) // b``.

    For a positive integer ``b`` this is ``a / b`` rounded up to the next
    integer, i.e. the number of size-``b`` blocks needed to cover ``a`` items.
    """
    padded = a + b - 1
    return padded // b
def test_main(num_tokens: int,
hidden: int,
num_experts: int,
......@@ -42,11 +46,16 @@ def test_main(num_tokens: int,
num_ranks: int,
group: dist.ProcessGroup,
buffer: deep_ep.Buffer,
enable_dispatch_ll_layered: bool = False,
enable_combine_overlap: bool = False,
use_logfmt: bool = False,
seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
print(f"enable_dispatch_ll_layered={enable_dispatch_ll_layered}, enable_combine_overlap={enable_combine_overlap}, use_logfmt={use_logfmt}")
assert not (use_logfmt and (enable_dispatch_ll_layered or enable_combine_overlap)), \
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
......@@ -84,10 +93,13 @@ def test_main(num_tokens: int,
hash_value, num_times = 0, 0
for x_i, current_x in enumerate(x_list):
for return_recv_hook in (False, True):
for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
if enable_combine_overlap and (not return_recv_hook): # return_recv_hook 为False 时,不能启用 overlop
continue
for quant_type in (0, 1, 2, 3,): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0
for fp8_round_scale in (False, True) if quant_type != 3 else (True, ):
for quant_group_size in (0, 128,) if quant_type >= 2 else (0, ):
for fp8_round_scale in (False, True) if quant_type != 3 else (True,):
for quant_group_size in (0, 128,) if quant_type >= 2 else (0,):
if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0):
continue
......@@ -131,9 +143,14 @@ def test_main(num_tokens: int,
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
if (enable_dispatch_ll_layered or enable_combine_overlap):
recv_src_info = recv_src_info[:num_valid_tokens] & int_mask # 掩掉多余的信息
else:
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x_amax)
if dispatch_use_quant:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
......@@ -148,6 +165,7 @@ def test_main(num_tokens: int,
if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_quant:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
......@@ -155,19 +173,42 @@ def test_main(num_tokens: int,
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
# Check combine correctness
for zero_copy in (False, ) if use_logfmt else (False, True, ):
for zero_copy in (False,) if use_logfmt else (False, True,):
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
use_logfmt=use_logfmt,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
if enable_combine_overlap:
block_m, threshold, num_sms = 64, 10, 3
total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m) # 每个本地专家 总的信号数??
comp_signal = torch.zeros(num_local_experts * total_num_per_expert, dtype=torch.int32, device='cuda')
for i in range(num_local_experts):
vaild_num = ceil_div(packed_recv_count[i], block_m)
comp_signal[i * total_num_per_expert:i * total_num_per_expert + vaild_num] = threshold
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
packed_recv_count=packed_recv_count,
comp_signal=comp_signal,
block_m=block_m,
threshold=threshold,
num_sms=num_sms,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
else:
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
use_logfmt=use_logfmt,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
hook() if return_recv_hook else event.current_stream_wait()
if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
......@@ -177,9 +218,13 @@ def test_main(num_tokens: int,
hash_value ^= hash_tensor(combined_x)
if rank == 0:
print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")
print("deep_ep 全部正确性测试完成")
if enable_dispatch_ll_layered or enable_combine_overlap:
return hash_value
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
......@@ -242,7 +287,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
enable_dispatch_ll_layered = args.enable_dispatch_ll_layered
enable_combine_overlap = args.enable_combine_overlap
if enable_dispatch_ll_layered:
enable_combine_overlap = True
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered=enable_dispatch_ll_layered)
if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group,
......@@ -251,7 +302,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
explicitly_destroy=True,
allow_mnnvl=args.allow_mnnvl)
allow_mnnvl=args.allow_mnnvl,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap
)
print("deep_ep 初始化完成")
test_main(num_tokens,
hidden,
num_experts,
......@@ -261,6 +316,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=1)
do_pressure_test = args.pressure_test
......@@ -276,6 +333,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed)
for _ in range(20):
assert test_main(num_tokens,
......@@ -287,6 +346,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group
......@@ -309,6 +370,10 @@ if __name__ == '__main__':
parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
# 新版 sbo 需要的
parser.add_argument('--enable-dispatch-ll-layered', action='store_true', help='Enable low-latency layered dispatch optimization')
parser.add_argument("--enable-combine-overlap", action='store_true', help='Enable GEMM-compute/communication overlap in the combine phase')
args = parser.parse_args()
num_processes = args.num_processes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment