Commit 30aa7a87 authored by lishen01's avatar lishen01
Browse files

完善torchrun和mpi启动的测试代码

parent 243eca85
#!/bin/bash
# rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
......@@ -5,6 +6,7 @@ export ROCSHMEM_MAX_NUM_CONTEXTS=60
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# NMZ使用
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
......@@ -18,8 +20,8 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PYTHONPATH=$(pwd)/../
# test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
#!/bin/bash
# rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
......@@ -5,6 +6,7 @@ export ROCSHMEM_MAX_NUM_CONTEXTS=60
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# NMZ使用
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
......@@ -18,8 +20,8 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PYTHONPATH=$(pwd)/../
# test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
......@@ -244,7 +244,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), explicitly_destroy=True)
torch.manual_seed(rank)
for i in (24, ):
for i in (60, ):
test_main(args, i, local_rank, num_ranks, rank, buffer, group)
if local_rank == 0:
print('', flush=True)
......
......@@ -52,8 +52,9 @@ def test_main(num_tokens: int,
seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
if rank == 0:
print(f"enable_dispatch_ll_layered={enable_dispatch_ll_layered}, enable_combine_overlap={enable_combine_overlap}, use_logfmt={use_logfmt}")
assert not (use_logfmt and (enable_dispatch_ll_layered or enable_combine_overlap)), \
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
assert num_experts % num_ranks == 0
......@@ -144,7 +145,7 @@ def test_main(num_tokens: int,
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
if (enable_dispatch_ll_layered or enable_combine_overlap):
if enable_dispatch_ll_layered or enable_combine_overlap:
recv_src_info = recv_src_info[:num_valid_tokens] & int_mask # 掩掉多余的信息
else:
recv_src_info = recv_src_info[:num_valid_tokens]
......@@ -179,7 +180,7 @@ def test_main(num_tokens: int,
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
if enable_combine_overlap:
block_m, threshold, num_sms = 64, 10, 3
total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m) # 每个本地专家 总的信号数??
total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m) # 每个本地专家 总的信号数
comp_signal = torch.zeros(num_local_experts * total_num_per_expert, dtype=torch.int32, device='cuda')
for i in range(num_local_experts):
......
......@@ -8,12 +8,15 @@ export PYTHONPATH=$(pwd)
# rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_MAX_NUM_CONTEXTS=60
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240
export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=$(pwd)/tests_mpi/topo.config
# NMZ使用
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export NVSHMEM_SYMMETRIC_SIZE=10737418240
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
# export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
# export NVSHMEM_SYMMETRIC_SIZE=10737418240
......@@ -145,7 +145,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
# Check `topk_weights`
if not is_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
max_weights = recv_topk_weights.amax(dim=1, keepdim=True) # Shape: [Batch, 1]
recv_topk_weights = torch.where(recv_topk_idx == -1, max_weights, recv_topk_weights)
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
# Test cached dispatch (must without top-k staffs)
......@@ -203,6 +204,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in range(4, 45, 4):
for rdma_chunk_size in range(4, 33, 4):
if rdma_buffer_size % rdma_chunk_size != 0:
continue
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': current_x, 'handle': handle, 'config': config}
t, notify_t = bench_kineto(lambda: buffer.dispatch(**tune_args), ('dispatch', 'notify'), suppress_kineto_output=True)
......@@ -235,6 +238,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 8, 1):
for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4):
if rdma_buffer_size % rdma_chunk_size != 0:
continue
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t, notify_t = bench_kineto(lambda: buffer.combine(**tune_args), ('combine', 'notify'), suppress_kineto_output=True)
......@@ -272,8 +277,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_rdma_bytes_ll = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
num_sms = 48
num_sms = 60
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
deep_ep.Buffer.set_num_sms(num_sms)
hidden_bytes = get_hidden_bytes(args)
num_nvl_bytes, num_rdma_bytes, num_rdma_bytes_norm = 0, 0, 0
......@@ -299,7 +305,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
break
if rank == 0:
print(f'{ref_hash=}')
print(f'ref_hash={ref_hash}')
print('', flush=True)
for j in range(20):
......
......@@ -119,7 +119,8 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
# Check `topk_weights`
recv_topk_weights_clone = recv_topk_weights.clone()
if current_x is not x_pure_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
max_weights = recv_topk_weights.amax(dim=1, keepdim=True) # Shape: [Batch, 1]
recv_topk_weights = torch.where(recv_topk_idx == -1, max_weights, recv_topk_weights)
check_data(recv_topk_weights, rank_prefix_matrix)
# Test `num_worst_tokens != 0`
......@@ -251,7 +252,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), explicitly_destroy=True)
torch.manual_seed(rank)
for i in (48, ):
for i in (60, ):
test_main(args, i, local_rank, num_ranks, rank, buffer, group)
if local_rank == 0:
print('', flush=True)
......
......@@ -36,6 +36,10 @@ def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], bu
assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks
def ceil_div(a, b):
return (a + b - 1) // b
def test_main(num_tokens: int,
hidden: int,
num_experts: int,
......@@ -44,11 +48,17 @@ def test_main(num_tokens: int,
num_ranks: int,
group: dist.ProcessGroup,
buffer: deep_ep.Buffer,
enable_dispatch_ll_layered: bool = False,
enable_combine_overlap: bool = False,
use_logfmt: bool = False,
seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
if rank == 0:
print(f"enable_dispatch_ll_layered={enable_dispatch_ll_layered}, enable_combine_overlap={enable_combine_overlap}, use_logfmt={use_logfmt}")
assert not (use_logfmt and (enable_dispatch_ll_layered or enable_combine_overlap)), \
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
......@@ -86,6 +96,9 @@ def test_main(num_tokens: int,
hash_value, num_times = 0, 0
for x_i, current_x in enumerate(x_list):
for return_recv_hook in (False, True):
if enable_combine_overlap and (not return_recv_hook): # return_recv_hook 为False 时,不能启用 overlop
continue
for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0
for fp8_round_scale in (False, True) if quant_type != 3 else (True, ):
......@@ -133,7 +146,12 @@ def test_main(num_tokens: int,
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
if enable_dispatch_ll_layered or enable_combine_overlap:
recv_src_info = recv_src_info[:num_valid_tokens] & int_mask # 掩掉多余的信息
else:
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x_amax)
if dispatch_use_quant:
......@@ -150,6 +168,7 @@ def test_main(num_tokens: int,
if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_quant:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
......@@ -161,6 +180,28 @@ def test_main(num_tokens: int,
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
if enable_combine_overlap:
block_m, threshold, num_sms = 64, 10, 3
total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m) # 每个本地专家 总的信号数
comp_signal = torch.zeros(num_local_experts * total_num_per_expert, dtype=torch.int32, device='cuda')
for i in range(num_local_experts):
vaild_num = ceil_div(packed_recv_count[i], block_m)
comp_signal[i * total_num_per_expert:i * total_num_per_expert + vaild_num] = threshold
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
packed_recv_count=packed_recv_count,
comp_signal=comp_signal,
block_m=block_m,
threshold=threshold,
num_sms=num_sms,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
else:
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
......@@ -170,6 +211,7 @@ def test_main(num_tokens: int,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
hook() if return_recv_hook else event.current_stream_wait()
if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
......@@ -181,8 +223,10 @@ def test_main(num_tokens: int,
if rank == 0:
print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")
if rank == 0:
print('', flush=True)
print("deep_ep 全部正确性测试完成")
if enable_dispatch_ll_layered or enable_combine_overlap:
return hash_value
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
......@@ -252,9 +296,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts
print(f"num_tokens, hidden, num_ranks, num_experts = {num_tokens}, {hidden}, {num_ranks}, {num_experts}")
enable_dispatch_ll_layered = args.enable_dispatch_ll_layered
enable_combine_overlap = args.enable_combine_overlap
if enable_dispatch_ll_layered:
enable_combine_overlap = True
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered=enable_dispatch_ll_layered)
if rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group,
......@@ -263,7 +311,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
explicitly_destroy=True,
allow_mnnvl=args.allow_mnnvl)
allow_mnnvl=args.allow_mnnvl,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap
)
print("deep_ep 初始化完成")
test_main(num_tokens,
hidden,
num_experts,
......@@ -273,6 +325,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=1)
do_pressure_test = args.pressure_test
......@@ -288,6 +342,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed)
for _ in range(20):
assert test_main(num_tokens,
......@@ -299,6 +355,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group
......@@ -331,6 +389,10 @@ if __name__ == '__main__':
parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
# 新版 sbo 需要的
parser.add_argument('--enable-dispatch-ll-layered', action='store_true', help='Enable low-latency layered dispatch optimization')
parser.add_argument("--enable-combine-overlap", action='store_true', help='Enable GEMM-compute/communication overlap in the combine phase')
args = parser.parse_args()
if args.world_size > args.num_processes:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment