Commit 61bc0aff authored by lishen's avatar lishen
Browse files

修改torchrun启动的测试

parent e195b4fe
...@@ -4,7 +4,7 @@ export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288 ...@@ -4,7 +4,7 @@ export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48 export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9 export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240 export ROCSHMEM_HEAP_SIZE=10737418240
export ROCSHMEM_TOPO_FILE_FORCE=tests/topo.config export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# duSHMEM # duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
...@@ -13,10 +13,10 @@ export NVSHMEM_SYMMETRIC_SIZE=10737418240 ...@@ -13,10 +13,10 @@ export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common # common
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PYTHONPATH=$(pwd) export PYTHONPATH=$(pwd)/../
# test # test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py --pressure-test torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py --test-ll-compatibility # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
...@@ -4,7 +4,7 @@ export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288 ...@@ -4,7 +4,7 @@ export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48 export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9 export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240 export ROCSHMEM_HEAP_SIZE=10737418240
export ROCSHMEM_TOPO_FILE_FORCE=tests/topo.config export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# duSHMEM # duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
...@@ -13,10 +13,10 @@ export NVSHMEM_SYMMETRIC_SIZE=10737418240 ...@@ -13,10 +13,10 @@ export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common # common
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PYTHONPATH=$(pwd) export PYTHONPATH=$(pwd)/../
# test # test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py --pressure-test torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py --test-ll-compatibility # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
import argparse
import random import random
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from functools import partial from functools import partial
from typing import Literal, Set
import deep_ep import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_pg_back from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_pg_back, per_token_cast_pc_back
def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]):
rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, seed: int = 0): # Simulates rank failure when the rank first calls the corresponding communication API
failed_api_ranks = {
# API -> rank to fail (rank fails when it first calls the corresponding communication API)
'dispatch': 1,
'combine': 3,
'clean': 5
}
if rank in expected_masked_ranks:
# Rank already failed
return True
if api in failed_api_ranks.keys():
expected_masked_ranks.add(failed_api_ranks[api])
if failed_api_ranks[api] == rank:
print(f"Rank {rank} failed when first calling {api} communication API, exit...", flush=True)
return True
return False
def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], buffer: deep_ep.Buffer, mask_status: torch.Tensor,
expected_masked_ranks: Set[int]):
buffer.low_latency_query_mask_buffer(mask_status)
assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks
def test_main(num_tokens: int,
hidden: int,
num_experts: int,
num_topk: int,
rank: int,
num_ranks: int,
group: dist.ProcessGroup,
buffer: deep_ep.Buffer,
use_logfmt: bool = False,
seed: int = 0):
torch.manual_seed(seed + rank) torch.manual_seed(seed + rank)
random.seed(seed + rank) random.seed(seed + rank)
assert num_experts % num_ranks == 0 assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks num_local_experts = num_experts // num_ranks
# NOTES: the integers greater than 256 exceeds the BF16 precision limit # NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset = 128 rank_offset = 128
assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)' assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset) x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1) x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
x_list = [x]
for _ in range(4 if use_logfmt else 0):
# NOTES: make more LogFMT casts and also with some BF16
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random())
# NOTES: the last one is for performance testing
# Most of the values in the perf case is lower than the threshold, casting most channels
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs() topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
# Randomly mask some positions # Randomly mask some positions
for i in range(10): for _ in range(10):
topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1 topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
# For failure simulation and shrink testing
mask_status = torch.zeros((num_ranks,), dtype=torch.int, device='cuda')
expected_masked_ranks = set()
# Check dispatch correctness # Check dispatch correctness
do_check = True do_check = True
hash_value, num_times = 0, 0 hash_value, num_times = 0, 0
for x_i, current_x in enumerate(x_list):
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
for dispatch_use_fp8 in (False, True): for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0
for fp8_round_scale in (False, True) if quant_type != 3 else (True, ):
for quant_group_size in (0, 128,) if quant_type >= 2 else (0, ):
if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0):
continue
num_times += 1 num_times += 1
for i in range((num_times % 2) + 1): for _ in range((num_times % 2) + 1):
packed_recv_x, packed_recv_count, handle, event, hook = \ packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
quant_type=quant_type, fp8_round_scale=fp8_round_scale, quant_group_size=quant_group_size,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait() hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_quant else packed_recv_x
# print('run {}/{}, dispatch_use_fp8={}'.format(i + 1, num_times, dispatch_use_fp8)) if not dispatch_use_quant:
# return simulated_gemm_x = packed_recv_x.clone()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x elif quant_group_size == 0:
simulated_gemm_x = per_token_cast_pg_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \ simulated_gemm_x = per_token_cast_pc_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].reshape(-1)).view(packed_recv_x[0].shape)
if dispatch_use_fp8 else packed_recv_x.clone() elif quant_group_size == 128:
# print(f"rank{rank}: packed_recv_x[0]\n{packed_recv_x[0].cpu()}\n") simulated_gemm_x = per_token_cast_pg_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
# print(f"rank{rank}: packed_recv_x[1]\n{packed_recv_x[1].cpu()}\n")
# print(f"simulated_gemm_x{simulated_gemm_x.cpu()}")
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
for i in range(num_local_experts if do_check else 0): for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i expert_id = rank * num_local_experts + i
recv_x = per_token_cast_pg_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i] if not dispatch_use_quant:
recv_x = packed_recv_x[i]
elif quant_group_size == 0:
recv_x = per_token_cast_pc_back(packed_recv_x[0][i], packed_recv_x[1][i])
elif quant_group_size == 128:
recv_x = per_token_cast_pg_back(packed_recv_x[0][i], packed_recv_x[1][i])
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i] recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices # Check expert indices
int_mask = (2 ** 32) - 1 int_mask = (2 ** 32) - 1
num_valid_tokens = recv_count.item() num_valid_tokens = recv_count.item()
assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()' assert num_valid_tokens == (
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}' recv_layout_range
& int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item(
), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item()}'
if num_valid_tokens == 0:
continue
# Check received data # Check received data
if current_x is x:
recv_x = recv_x[:num_valid_tokens] recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1) recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens] recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x_amax)
if dispatch_use_quant:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
if quant_group_size != 0:
if fp8_round_scale:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks): for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0 assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_fp8: if dispatch_use_quant:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
else: else:
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
# Check combine correctness # Check combine correctness
for zero_copy in (False, True): for zero_copy in (False, ) if use_logfmt else (False, True, ):
if zero_copy: if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
use_logfmt=use_logfmt,
async_finish=not return_recv_hook, async_finish=not return_recv_hook,
return_recv_hook=return_recv_hook, out=out) zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
hook() if return_recv_hook else event.current_stream_wait() hook() if return_recv_hook else event.current_stream_wait()
if do_check: if do_check:
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0 assert torch.isnan(combined_x).sum().item() == 0
assert diff < 1e-5, f'Error: diff={diff}' # if not fp8_round_scale:
assert diff < (9e-4 if dispatch_use_quant else 1e-5), f'Error: diff={diff}, dispatch_use_quant={dispatch_use_quant}, zero_copy={zero_copy}'
hash_value ^= hash_tensor(combined_x) hash_value ^= hash_tensor(combined_x)
if rank == 0:
print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")
# noinspection PyShadowingNames # noinspection PyShadowingNames
def large_gemm_with_hook(hook): def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float) mat_0 = torch.randn((8192, 8192), dtype=torch.float)
...@@ -101,19 +188,23 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -101,19 +188,23 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hook() hook()
# noinspection PyShadowingNames # noinspection PyShadowingNames
def test_func(zero_copy: bool, return_recv_hook: bool): def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \ recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
quant_type=2, quant_group_size=0,
async_finish=False, return_recv_hook=return_recv_hook) async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
if zero_copy: combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x topk_idx,
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, topk_weights,
zero_copy=zero_copy, return_recv_hook=return_recv_hook) handle,
use_logfmt=use_logfmt,
return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
# Calculate bandwidth # Calculate bandwidth
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2 num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0 num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens): for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item() num_selections = (topk_idx[i] != -1).sum().item()
...@@ -121,54 +212,104 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -121,54 +212,104 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
num_combine_comm_bytes += num_bf16_bytes * num_selections num_combine_comm_bytes += num_bf16_bytes * num_selections
# Dispatch + combine testing # Dispatch + combine testing
avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False)) avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, ' print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True) f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us',
flush=True)
# Separate profiling # Separate profiling
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
group.barrier() group.barrier()
dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook), kernel_names=('dispatch', 'combine'),
kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True, barrier_comm_profiling=True,
suppress_kineto_output=True) suppress_kineto_output=True,
num_kernels_per_period=2 if return_recv_hook else 1)
if not return_recv_hook: if not return_recv_hook:
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us') f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us',
flush=True)
else: else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | ' print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} us') f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us',
flush=True)
return hash_value return hash_value
# noinspection PyUnboundLocalVariable # noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int): def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks) rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
# The default setting of deepEP upstream is below: num_tokens, hidden = args.num_tokens, args.hidden
num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 256 num_topk, num_experts = args.num_topk, args.num_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0: if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True) print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, buffer = deep_ep.Buffer(group,
num_qps_per_rank=num_experts // num_ranks, explicitly_destroy=True) num_rdma_bytes=num_rdma_bytes,
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1) low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
explicitly_destroy=True,
allow_mnnvl=args.allow_mnnvl)
test_main(num_tokens,
hidden,
num_experts,
num_topk,
rank,
num_ranks,
group,
buffer,
use_logfmt=args.use_logfmt,
seed=1)
do_pressure_test = False do_pressure_test = args.pressure_test
for seed in range(int(1e9) if do_pressure_test else 0): for seed in range(int(1e9) if do_pressure_test else 0):
if local_rank == 0: if local_rank == 0:
print(f'Testing with seed {seed} ...', flush=True) print(f'Testing with seed {seed} ...', flush=True)
ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) ref_hash = test_main(num_tokens,
for i in range(20): hidden,
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}' num_experts,
num_topk,
rank,
num_ranks,
group,
buffer,
use_logfmt=args.use_logfmt,
seed=seed)
for _ in range(20):
assert test_main(num_tokens,
hidden,
num_experts,
num_topk,
rank,
num_ranks,
group,
buffer,
use_logfmt=args.use_logfmt,
seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group
buffer.destroy() buffer.destroy()
dist.barrier() dist.barrier()
dist.destroy_process_group() dist.destroy_process_group()
if __name__ == '__main__': if __name__ == '__main__':
print("main start...")
# TODO: you may modify NUMA binding for less CPU overhead # TODO: you may modify NUMA binding for less CPU overhead
num_processes = 8 # TODO: buggy with `num_tokens=512`
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes) parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)')
parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=288, help='Number of experts (default: 288)')
parser.add_argument('--allow-mnnvl', action="store_true", help='Allow MNNVL for communication')
parser.add_argument('--disable-nvlink', action='store_true', help='Whether to disable NVLink for testing')
parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
args = parser.parse_args()
num_processes = args.num_processes
torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
import argparse
import random
import torch
import torch.distributed as dist
from functools import partial
from typing import Literal, Set
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_pg_back, per_token_cast_pc_back
def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]):
# Simulates rank failure when the rank first calls the corresponding communication API
failed_api_ranks = {
# API -> rank to fail (rank fails when it first calls the corresponding communication API)
'dispatch': 1,
'combine': 3,
'clean': 5
}
if rank in expected_masked_ranks:
# Rank already failed
return True
if api in failed_api_ranks.keys():
expected_masked_ranks.add(failed_api_ranks[api])
if failed_api_ranks[api] == rank:
print(f"Rank {rank} failed when first calling {api} communication API, exit...", flush=True)
return True
return False
def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], buffer: deep_ep.Buffer, mask_status: torch.Tensor,
expected_masked_ranks: Set[int]):
buffer.low_latency_query_mask_buffer(mask_status)
assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks
def test_main(num_tokens: int,
hidden: int,
num_experts: int,
num_topk: int,
rank: int,
num_ranks: int,
group: dist.ProcessGroup,
buffer: deep_ep.Buffer,
use_logfmt: bool = False,
seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset = 128
assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
x_list = [x]
for _ in range(4 if use_logfmt else 0):
# NOTES: make more LogFMT casts and also with some BF16
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random())
# NOTES: the last one is for performance testing
# Most of the values in the perf case is lower than the threshold, casting most channels
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
# Randomly mask some positions
for _ in range(10):
topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
# For failure simulation and shrink testing
mask_status = torch.zeros((num_ranks,), dtype=torch.int, device='cuda')
expected_masked_ranks = set()
# Check dispatch correctness
do_check = True
hash_value, num_times = 0, 0
for x_i, current_x in enumerate(x_list):
for return_recv_hook in (False, True):
for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0
for fp8_round_scale in (False, True) if quant_type != 3 else (True, ):
for quant_group_size in (0, 128,) if quant_type >= 2 else (0, ):
if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0):
continue
num_times += 1
for _ in range((num_times % 2) + 1):
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
quant_type=quant_type, fp8_round_scale=fp8_round_scale, quant_group_size=quant_group_size,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_quant else packed_recv_x
if not dispatch_use_quant:
simulated_gemm_x = packed_recv_x.clone()
elif quant_group_size == 0:
simulated_gemm_x = per_token_cast_pc_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].reshape(-1)).view(packed_recv_x[0].shape)
elif quant_group_size == 128:
simulated_gemm_x = per_token_cast_pg_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i
if not dispatch_use_quant:
recv_x = packed_recv_x[i]
elif quant_group_size == 0:
recv_x = per_token_cast_pc_back(packed_recv_x[0][i], packed_recv_x[1][i])
elif quant_group_size == 128:
recv_x = per_token_cast_pg_back(packed_recv_x[0][i], packed_recv_x[1][i])
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices
int_mask = (2 ** 32) - 1
num_valid_tokens = recv_count.item()
assert num_valid_tokens == (
recv_layout_range
& int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item(
), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item()}'
if num_valid_tokens == 0:
continue
# Check received data
if current_x is x:
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x_amax)
if dispatch_use_quant:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
if quant_group_size != 0:
if fp8_round_scale:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_quant:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
else:
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
# Check combine correctness
for zero_copy in (False, ) if use_logfmt else (False, True, ):
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
use_logfmt=use_logfmt,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
hook() if return_recv_hook else event.current_stream_wait()
if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
# if not fp8_round_scale:
assert diff < (9e-4 if dispatch_use_quant else 1e-5), f'Error: diff={diff}, dispatch_use_quant={dispatch_use_quant}, zero_copy={zero_copy}'
hash_value ^= hash_tensor(combined_x)
if rank == 0:
print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
mat_1 = torch.randn((8192, 8192), dtype=torch.float)
mat_0 @ mat_1
hook()
# noinspection PyShadowingNames
def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
quant_type=2, quant_group_size=0,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
use_logfmt=use_logfmt,
return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
# Calculate bandwidth
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item()
num_dispatch_comm_bytes += num_fp8_bytes * num_selections
num_combine_comm_bytes += num_bf16_bytes * num_selections
# Dispatch + combine testing
avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us',
flush=True)
# Separate profiling
for return_recv_hook in (False, True):
group.barrier()
dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
kernel_names=('dispatch', 'combine'),
barrier_comm_profiling=True,
suppress_kineto_output=True,
num_kernels_per_period=2 if return_recv_hook else 1)
if not return_recv_hook:
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us',
flush=True)
else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us',
flush=True)
return hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
explicitly_destroy=True,
allow_mnnvl=args.allow_mnnvl)
test_main(num_tokens,
hidden,
num_experts,
num_topk,
rank,
num_ranks,
group,
buffer,
use_logfmt=args.use_logfmt,
seed=1)
do_pressure_test = args.pressure_test
for seed in range(int(1e9) if do_pressure_test else 0):
if local_rank == 0:
print(f'Testing with seed {seed} ...', flush=True)
ref_hash = test_main(num_tokens,
hidden,
num_experts,
num_topk,
rank,
num_ranks,
group,
buffer,
use_logfmt=args.use_logfmt,
seed=seed)
for _ in range(20):
assert test_main(num_tokens,
hidden,
num_experts,
num_topk,
rank,
num_ranks,
group,
buffer,
use_logfmt=args.use_logfmt,
seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group
buffer.destroy()
dist.barrier()
dist.destroy_process_group()
if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: buggy with `num_tokens=512`
parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)')
parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=288, help='Number of experts (default: 288)')
parser.add_argument('--allow-mnnvl', action="store_true", help='Allow MNNVL for communication')
parser.add_argument('--disable-nvlink', action='store_true', help='Whether to disable NVLink for testing')
parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
args = parser.parse_args()
num_processes = args.num_processes
torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment