"""Correctness and performance test for DeepEP low-latency dispatch/combine kernels.

Rank information defaults to the OpenMPI launcher environment variables
(``OMPI_COMM_WORLD_RANK`` / ``OMPI_COMM_WORLD_SIZE`` / ``OMPI_COMM_WORLD_LOCAL_RANK``);
see the argument parser at the bottom of the file.
"""
import argparse
import os
import random
import socket
from functools import partial
from typing import Literal, Set

import torch
import torch.distributed as dist

import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_pg_back, per_token_cast_pc_back


def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"],
                              expected_masked_ranks: Set[int]) -> bool:
    """Simulate a rank failure the first time a given communication API is called.

    Each API listed in ``failed_api_ranks`` has one designated rank that "fails" on its first
    call. Calling this records that designated rank in ``expected_masked_ranks`` (mutated in place).

    Returns:
        True if ``rank`` should skip the call, because it is already in ``expected_masked_ranks``
        or it is the rank designated to fail for ``api``. False otherwise.
    """
    # API -> rank to fail (the rank fails when it first calls the corresponding communication API)
    failed_api_ranks = {
        'dispatch': 1,
        'combine': 3,
        'clean': 5,
    }
    if rank in expected_masked_ranks:
        # Rank already failed
        return True
    if api in failed_api_ranks:
        failed_rank = failed_api_ranks[api]
        expected_masked_ranks.add(failed_rank)
        if failed_rank == rank:
            print(f"Rank {rank} failed when first calling {api} communication API, exit...", flush=True)
            return True
    return False


def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], buffer: deep_ep.Buffer,
                                mask_status: torch.Tensor, expected_masked_ranks: Set[int]) -> None:
    """Read the buffer's rank mask into ``mask_status`` and assert it matches ``expected_masked_ranks``.

    ``mask_status`` is filled by ``buffer.low_latency_query_mask_buffer``. The indices of its
    non-zero entries are compared with the expected set. ``api`` is currently unused.
    """
    buffer.low_latency_query_mask_buffer(mask_status)
    assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks


def ceil_div(a, b):
    """Return ``a / b`` rounded up, for non-negative ``a`` and positive ``b``."""
    return (a + b - 1) // b


def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
              rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer,
              enable_dispatch_ll_layered: bool = False, enable_combine_overlap: bool = False,
              use_logfmt: bool = False, seed: int = 0) -> int:
    """Check low-latency dispatch/combine correctness, then benchmark them.

    Correctness is checked over:
      - every input in ``x_list``;
      - both ``return_recv_hook`` modes (only ``True`` when combine overlap is enabled);
      - every tested quantization configuration (``quant_type`` x ``fp8_round_scale`` x ``quant_group_size``);
      - both combine ``zero_copy`` modes (only ``False`` with LogFMT).

    The benchmarks are skipped when ``enable_dispatch_ll_layered`` or ``enable_combine_overlap`` is set.

    Returns:
        XOR of ``hash_tensor`` over the received and combined tensors. ``test_loop`` compares it
        across repeated runs with the same ``seed`` in the pressure test.
    """
    torch.manual_seed(seed + rank)
    random.seed(seed + rank)

    if rank == 0:
        print(f"enable_dispatch_ll_layered={enable_dispatch_ll_layered}, enable_combine_overlap={enable_combine_overlap}, use_logfmt={use_logfmt}")
    assert not (use_logfmt and (enable_dispatch_ll_layered or enable_combine_overlap)), \
        "use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks

    # NOTES: the integers greater than 256 exceed the BF16 precision limit
    rank_offset = 128
    assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'

    # Every element equals (rank - rank_offset), except the last 128 columns, which hold the token
    # index. This lets the checks below verify the source rank and token of each received row.
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
    x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
    x_list = [x]
    for _ in range(4 if use_logfmt else 0):
        # NOTES: make more LogFMT casts and also with some BF16
        x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random())
    # NOTES: the last one is for performance testing
    # Most of the values in the perf case is lower than the threshold, casting most channels
    x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)

    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()

    # Randomly mask some positions (-1 entries are excluded from routing and from the combine reference)
    for _ in range(10):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1

    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)

    # For failure simulation and shrink testing: ranks with a non-zero entry are excluded from the
    # expected token counts. Nothing in this test sets it, so every rank counts.
    mask_status = torch.zeros((num_ranks,), dtype=torch.int, device='cuda')

    # Check dispatch correctness
    do_check = True
    int_mask = (2 ** 32) - 1
    hash_value, num_times = 0, 0
    for x_i, current_x in enumerate(x_list):
        for return_recv_hook in (False, True):
            if enable_combine_overlap and not return_recv_hook:
                # Combine overlap cannot be enabled when return_recv_hook is False
                continue
            # 0: no quantization, 1: INT8, 2: FP8_E4M3, 3: FP8_UE8M0 (fp8_round_scale=True only),
            # 4: FP8_E5M2 (not exercised here)
            for quant_type in (0, 1, 2, 3):
                dispatch_use_quant = quant_type > 0
                for fp8_round_scale in (False, True) if quant_type != 3 else (True, ):
                    for quant_group_size in (0, 128) if quant_type >= 2 else (0, ):
                        if quant_type == 3 and (not fp8_round_scale or quant_group_size == 0):
                            # FP8_UE8M0 is only tested with rounded scales and 128-channel groups
                            continue
                        num_times += 1
                        # Alternate between one and two dispatch calls per configuration
                        for _ in range((num_times % 2) + 1):
                            packed_recv_x, packed_recv_count, handle, event, hook = \
                                buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                                            quant_type=quant_type, fp8_round_scale=fp8_round_scale,
                                                            quant_group_size=quant_group_size,
                                                            async_finish=not return_recv_hook,
                                                            return_recv_hook=return_recv_hook)
                            if return_recv_hook:
                                hook()
                            else:
                                event.current_stream_wait()
                        if dispatch_use_quant:
                            # Quantized output is a (data, scales) pair
                            packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous())

                        # Dequantized copy of the received data, used as the "GEMM output" fed to combine
                        if not dispatch_use_quant:
                            simulated_gemm_x = packed_recv_x.clone()
                        elif quant_group_size == 0:
                            simulated_gemm_x = per_token_cast_pc_back(
                                packed_recv_x[0].view(-1, hidden),
                                packed_recv_x[1].reshape(-1)).view(packed_recv_x[0].shape)
                        elif quant_group_size == 128:
                            simulated_gemm_x = per_token_cast_pg_back(
                                packed_recv_x[0].view(-1, hidden),
                                packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)

                        for i in range(num_local_experts if do_check else 0):
                            expert_id = rank * num_local_experts + i
                            if not dispatch_use_quant:
                                recv_x = packed_recv_x[i]
                            elif quant_group_size == 0:
                                recv_x = per_token_cast_pc_back(packed_recv_x[0][i], packed_recv_x[1][i])
                            elif quant_group_size == 128:
                                recv_x = per_token_cast_pg_back(packed_recv_x[0][i], packed_recv_x[1][i])
                            recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]

                            # Check expert indices. For source rank j, recv_layout_range[j] packs the
                            # begin index (high 32 bits) and token count (low 32 bits).
                            num_valid_tokens = recv_count.item()
                            assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), \
                                f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
                            assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item(), \
                                f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item()}'
                            if num_valid_tokens == 0:
                                continue

                            # Check received data (only the structured input `x` has checkable values)
                            if current_x is x:
                                recv_x = recv_x[:num_valid_tokens]
                                recv_x_amin = recv_x[:, :-128].amin(dim=-1)
                                recv_x_amax = recv_x[:, :-128].amax(dim=-1)
                                recv_src_info = recv_src_info[:num_valid_tokens]
                                if enable_dispatch_ll_layered or enable_combine_overlap:
                                    # Mask off the extra information stored above the low 32 bits in these modes
                                    recv_src_info = recv_src_info & int_mask
                                # Each row (excluding the last 128 columns) must be constant
                                assert torch.equal(recv_x_amin, recv_x_amax)
                                if dispatch_use_quant:
                                    assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
                                if quant_group_size != 0 and not fp8_round_scale:
                                    assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
                                for j in range(num_ranks):
                                    begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
                                    if not fp8_round_scale:
                                        assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
                                        assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
                            if dispatch_use_quant:
                                hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
                                hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
                            else:
                                hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])

                        # Check combine correctness
                        for zero_copy in (False, ) if use_logfmt else (False, True):
                            if zero_copy:
                                buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
                            out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
                            if enable_combine_overlap:
                                block_m, threshold, num_sms = 64, 10, 3
                                # Total number of signal slots per local expert
                                total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m)
                                comp_signal = torch.zeros(num_local_experts * total_num_per_expert,
                                                          dtype=torch.int32, device='cuda')
                                for local_expert in range(num_local_experts):
                                    # Set the signal of every block_m-row block that holds received tokens
                                    # to `threshold` (presumably marking its GEMM as finished)
                                    valid_num = ceil_div(packed_recv_count[local_expert].item(), block_m)
                                    begin = local_expert * total_num_per_expert
                                    comp_signal[begin:begin + valid_num] = threshold
                                combined_x, event, hook = buffer.low_latency_combine(
                                    simulated_gemm_x, topk_idx, topk_weights, handle,
                                    packed_recv_count=packed_recv_count, comp_signal=comp_signal,
                                    block_m=block_m, threshold=threshold, num_sms=num_sms,
                                    async_finish=not return_recv_hook, zero_copy=zero_copy,
                                    return_recv_hook=return_recv_hook, out=out)
                            else:
                                combined_x, event, hook = buffer.low_latency_combine(
                                    simulated_gemm_x, topk_idx, topk_weights, handle,
                                    use_logfmt=use_logfmt, async_finish=not return_recv_hook,
                                    zero_copy=zero_copy, return_recv_hook=return_recv_hook, out=out)
                            if return_recv_hook:
                                hook()
                            else:
                                event.current_stream_wait()
                            if do_check:
                                # Reference: each token scaled by the sum of its unmasked top-k weights
                                diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
                                assert torch.isnan(combined_x).sum().item() == 0
                                assert diff < (9e-4 if dispatch_use_quant else 1e-5), f'Error: diff={diff}, dispatch_use_quant={dispatch_use_quant}, zero_copy={zero_copy}'
                                hash_value ^= hash_tensor(combined_x)
                        if rank == 0:
                            print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
                                  f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")
    print("deep_ep 全部正确性测试完成")

    if enable_dispatch_ll_layered or enable_combine_overlap:
        # The benchmarks below are not run in these modes
        return hash_value

    # noinspection PyShadowingNames
    def large_gemm_with_hook(hook):
        """Run a large matmul, then call ``hook`` (the receive is issued after unrelated work)."""
        mat_0 = torch.randn((8192, 8192), dtype=torch.float)
        mat_1 = torch.randn((8192, 8192), dtype=torch.float)
        mat_0 @ mat_1
        hook()

    # noinspection PyShadowingNames
    def test_func(return_recv_hook: bool):
        """One dispatch + combine round trip.

        Uses the last ``current_x`` / ``simulated_gemm_x`` left over from the correctness loop above.
        """
        recv_x, recv_count, handle, event, hook = \
            buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                        quant_type=2, quant_group_size=0,
                                        async_finish=False, return_recv_hook=return_recv_hook)
        if return_recv_hook:
            large_gemm_with_hook(hook)
        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                             use_logfmt=use_logfmt, return_recv_hook=return_recv_hook)
        if return_recv_hook:
            large_gemm_with_hook(hook)

    # Calculate bandwidth: one token payload per routed (token, top-k) selection, in each direction
    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
    num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4
    num_combine_token_bytes = num_logfmt10_bytes if use_logfmt else num_bf16_bytes
    num_selections = (topk_idx != -1).sum().item()
    num_dispatch_comm_bytes = num_fp8_bytes * num_selections
    num_combine_comm_bytes = num_combine_token_bytes * num_selections

    # Dispatch + combine testing
    avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
    print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
          f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)

    # Separate profiling; with return_recv_hook, each kernel is timed as separate send and recv parts
    for return_recv_hook in (False, True):
        group.barrier()
        dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
                                             kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
                                             suppress_kineto_output=True,
                                             num_kernels_per_period=2 if return_recv_hook else 1)
        if not return_recv_hook:
            print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
                  f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True)
        else:
            print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
                  f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True)
    return hash_value


# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    """Per-process entry point: run the test and tear everything down.

    Steps:
      1. Initialize the distributed group and the DeepEP buffer.
      2. Run ``test_main`` once, then optionally run the pressure test.
      3. Destroy the buffer and the process group.

    ``local_rank`` and ``num_local_ranks`` are unused; ``init_dist`` reads the rank values from ``args``.
    """
    rank, num_ranks, group = init_dist(args.rank, args.world_size, args.local_rank, args.dist_url)
    num_nodes = args.world_size // args.num_processes
    hostname = socket.gethostname()
    ip = socket.gethostbyname(hostname)
    print(f"rank={rank}, num_ranks={num_ranks}, num_nodes={num_nodes}, ip={ip}")
    num_tokens, hidden = args.num_tokens, args.hidden
    num_topk, num_experts = args.num_topk, args.num_experts
    enable_dispatch_ll_layered = args.enable_dispatch_ll_layered
    # Layered dispatch implies combine overlap
    enable_combine_overlap = args.enable_combine_overlap or enable_dispatch_ll_layered

    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts,
                                                                   enable_dispatch_ll_layered=enable_dispatch_ll_layered)
    if rank == 0:
        print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
    buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
                            num_qps_per_rank=num_experts // num_ranks,
                            allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
                            explicitly_destroy=True,
                            allow_mnnvl=args.allow_mnnvl,
                            enable_dispatch_ll_layered=enable_dispatch_ll_layered,
                            enable_combine_overlap=enable_combine_overlap)
    print("deep_ep 初始化完成")

    # All test_main invocations share every argument except the seed
    run_test = partial(test_main, num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
                       use_logfmt=args.use_logfmt,
                       enable_dispatch_ll_layered=enable_dispatch_ll_layered,
                       enable_combine_overlap=enable_combine_overlap)
    run_test(seed=1)

    # Pressure test: repeated runs with the same seed must produce the same hash
    for seed in range(int(1e9) if args.pressure_test else 0):
        if rank == 0:
            print(f'Testing with seed {seed} ...', flush=True)
        ref_hash = run_test(seed=seed)
        for _ in range(20):
            assert run_test(seed=seed) == ref_hash, f'Error: seed={seed}'

    # Destroy the buffer runtime and communication group
    buffer.destroy()
    dist.barrier()
    dist.destroy_process_group()


if __name__ == '__main__':
    # TODO: you may modify NUMA binding for less CPU overhead
    parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
    group = parser.add_argument_group(title='extra distributed args')
    # Defaults come from the OpenMPI launcher environment
    group.add_argument('--rank', default=int(os.getenv('OMPI_COMM_WORLD_RANK', '0')), type=int,
                       help='node rank for distributed training')
    group.add_argument('--world-size', type=int, default=int(os.getenv('OMPI_COMM_WORLD_SIZE', '0')),
                       help='number of nodes for distributed training')
    group.add_argument('--local-rank', type=int, default=int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK', '0')),
                       help='local rank passed from distributed launcher.')
    group.add_argument('--dist-url', help='Which master node url for distributed training.')
    parser.add_argument('--num-processes', type=int, default=8,
                        help='Number of processes to spawn (default: 8)')
    parser.add_argument('--num-tokens', type=int, default=128,
                        help='Number of tokens (default: 128)')
    parser.add_argument('--hidden', type=int, default=7168,
                        help='Hidden dimension size (default: 7168)')
    parser.add_argument('--num-topk', type=int, default=8,
                        help='Number of top-k experts (default: 8)')
    parser.add_argument('--num-experts', type=int, default=288,
                        help='Number of experts (default: 288)')
    parser.add_argument('--allow-mnnvl', action="store_true",
                        help='Allow MNNVL for communication')
    parser.add_argument('--disable-nvlink', action='store_true',
                        help='Whether to disable NVLink for testing')
    parser.add_argument("--pressure-test", action='store_true',
                        help='Whether to do pressure test')
    parser.add_argument("--shrink-test", action='store_true',
                        help='Whether to simulate failure and test shrink mode')
    parser.add_argument('--use-logfmt', action='store_true',
                        help='Whether to test LogFMT combine')
    # Options required by the newer "sbo" code path
    parser.add_argument('--enable-dispatch-ll-layered', action='store_true',
                        help='Enable low-latency layered dispatch optimization')
    parser.add_argument("--enable-combine-overlap", action='store_true',
                        help='Enable GEMM-compute/communication overlap in the combine phase')
    args = parser.parse_args()

    if args.world_size > args.num_processes:
        test_loop(args.local_rank, args.num_processes, args)
    else:
        # Previously this case exited silently; say why nothing ran
        print(f'Nothing to run: --world-size ({args.world_size}) must be greater than '
              f'--num-processes ({args.num_processes})', flush=True)