# test_low_latency.py: correctness and performance tests for the low-latency EP kernels.
import argparse
import os
import random
import torch
import torch.distributed as dist
import socket
from functools import partial
from typing import Literal, Set

import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_pg_back, per_token_cast_pc_back


def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]):
    """Simulate a rank failure the first time a given communication API is called.

    Each API listed in ``failed_api_ranks`` has one designated rank that "fails"
    when it first calls that API. Whenever this function is called for such an
    API, the designated rank is recorded in ``expected_masked_ranks`` so the
    caller knows which ranks should be reported as masked afterwards.

    Args:
        rank: rank of the calling process.
        api: the communication API about to be called.
        expected_masked_ranks: ranks already considered failed; mutated in place.

    Returns:
        True if the calling rank has failed (just now or earlier) and should
        skip the call, False otherwise.
    """
    # API -> rank to fail (rank fails when it first calls the corresponding communication API)
    failed_api_ranks = {
        'dispatch': 1,
        'combine': 3,
        'clean': 5,
    }
    if rank in expected_masked_ranks:
        # Rank already failed
        return True

    failed_rank = failed_api_ranks.get(api)
    if failed_rank is None:
        # No failure is simulated for this API
        return False

    expected_masked_ranks.add(failed_rank)
    if failed_rank == rank:
        print(f"Rank {rank} failed when first calling {api} communication API, exit...", flush=True)
        return True
    return False


def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], buffer: deep_ep.Buffer, mask_status: torch.Tensor,
                                expected_masked_ranks: Set[int]):
    """Fill ``mask_status`` from the buffer and assert the masked ranks match the expectation.

    A rank counts as masked when its entry in ``mask_status`` is non-zero.
    ``api`` is currently unused.
    """
    buffer.low_latency_query_mask_buffer(mask_status)
    masked_rank_list = mask_status.nonzero().squeeze(-1).tolist()
    assert set(masked_rank_list) == expected_masked_ranks


def ceil_div(a, b):
    """Return ceil(a / b) for an integer ``a`` and a positive integer ``b``.

    Uses only ``+``, ``-`` and ``//``, so it also works element-wise on integer
    torch tensors (it is called with a tensor count elsewhere in this file).
    """
    padded = a + b - 1
    return padded // b


def test_main(num_tokens: int,
              hidden: int,
              num_experts: int,
              num_topk: int,
              rank: int,
              num_ranks: int,
              group: dist.ProcessGroup,
              buffer: deep_ep.Buffer,
              enable_dispatch_ll_layered: bool = False,
              enable_combine_overlap: bool = False,
              use_logfmt: bool = False,
              seed: int = 0):
    """Check low-latency dispatch/combine correctness, then benchmark them.

    Every combination of input tensor, ``return_recv_hook``, dispatch
    quantization type, FP8 scale rounding and quantization group size is sent
    through ``buffer.low_latency_dispatch`` / ``buffer.low_latency_combine``.
    The received token counts, received data and combined output are asserted
    against references built from the all-gathered ``topk_idx``. Unless layered
    dispatch or combine overlap is enabled, dispatch/combine bandwidth and
    timings are then printed.

    Args:
        num_tokens: number of tokens generated on this rank.
        hidden: hidden size; the last 128 channels of the reference input carry the token index.
        num_experts: total number of experts; must be divisible by ``num_ranks``.
        num_topk: number of experts selected per token.
        rank: rank of this process.
        num_ranks: number of ranks in ``group``.
        group: process group used for the all-gather of routing indices.
        buffer: the low-latency ``deep_ep.Buffer`` under test.
        enable_dispatch_ll_layered: whether layered low-latency dispatch is enabled
            (upper bits of ``recv_src_info`` are masked off; benchmarks are skipped).
        enable_combine_overlap: exercise combine with ``comp_signal`` overlap
            (only the ``return_recv_hook=True`` path is tested; benchmarks are skipped).
        use_logfmt: test LogFMT combine; conflicts with the two flags above.
        seed: base RNG seed; this rank uses ``seed + rank``.

    Returns:
        An XOR of ``hash_tensor`` values over received and combined tensors,
        used by the pressure test to check run-to-run determinism.
    """
    torch.manual_seed(seed + rank)
    random.seed(seed + rank)

    if rank == 0:
        print(f"enable_dispatch_ll_layered={enable_dispatch_ll_layered}, enable_combine_overlap={enable_combine_overlap}, use_logfmt={use_logfmt}")

    assert not (use_logfmt and (enable_dispatch_ll_layered or enable_combine_overlap)), \
        "use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"

    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks

    # NOTES: the integers greater than 256 exceed the BF16 precision limit
    rank_offset = 128
    assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'

    # Reference input: every channel holds (rank - rank_offset), except the last
    # 128 channels, which hold the token index
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
    x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
    x_list = [x]
    for _ in range(4 if use_logfmt else 0):
        # NOTES: make more LogFMT casts and also with some BF16
        x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random())
    # NOTES: the last one is for performance testing
    # Most of the values in the perf case is lower than the threshold, casting most channels
    x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)

    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()

    # Randomly mask some positions
    for _ in range(10):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1

    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)

    # Per-rank mask (non-zero = masked). All ranks are unmasked here; tokens from
    # masked ranks would be excluded from the expected per-expert counts below.
    mask_status = torch.zeros((num_ranks,), dtype=torch.int, device='cuda')

    # recv_layout_range packs (begin_idx << 32) | count; this extracts the low 32 bits
    int_mask = (2 ** 32) - 1

    # Check dispatch correctness
    do_check = True
    hash_value, num_times = 0, 0
    for x_i, current_x in enumerate(x_list):
        for return_recv_hook in (False, True):
            # Combine overlap cannot be enabled when return_recv_hook is False
            if enable_combine_overlap and not return_recv_hook:
                continue

            # quant_type: 0 = no quantization, 1 = int8, 2 = FP8_E4M3,
            # 3 = FP8_UE8M0 (only supported with round_scale=True); 4 = FP8_E5M2 (not tested here)
            for quant_type in (0, 1, 2, 3, ):
                dispatch_use_quant = quant_type > 0
                for fp8_round_scale in (False, True) if quant_type != 3 else (True, ):
                    for quant_group_size in (0, 128,) if quant_type >= 2 else (0, ):
                        if quant_type == 3 and (not fp8_round_scale or quant_group_size == 0):
                            continue

                        num_times += 1
                        for _ in range((num_times % 2) + 1):
                            packed_recv_x, packed_recv_count, handle, event, hook = \
                                buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                                            quant_type=quant_type, fp8_round_scale=fp8_round_scale, quant_group_size=quant_group_size,
                                                            async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
                            if return_recv_hook:
                                hook()
                            else:
                                event.current_stream_wait()
                        packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_quant else packed_recv_x
                        if not dispatch_use_quant:
                            simulated_gemm_x = packed_recv_x.clone()
                        elif quant_group_size == 0:
                            simulated_gemm_x = per_token_cast_pc_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].reshape(-1)).view(packed_recv_x[0].shape)
                        elif quant_group_size == 128:
                            simulated_gemm_x = per_token_cast_pg_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
                        for i in range(num_local_experts if do_check else 0):
                            expert_id = rank * num_local_experts + i
                            if not dispatch_use_quant:
                                recv_x = packed_recv_x[i]
                            elif quant_group_size == 0:
                                recv_x = per_token_cast_pc_back(packed_recv_x[0][i], packed_recv_x[1][i])
                            elif quant_group_size == 128:
                                recv_x = per_token_cast_pg_back(packed_recv_x[0][i], packed_recv_x[1][i])
                            recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]

                            # Check expert indices
                            num_valid_tokens = recv_count.item()
                            assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), \
                                f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
                            expected_num_tokens = (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item()
                            assert num_valid_tokens == expected_num_tokens, f'{num_valid_tokens} != {expected_num_tokens}'

                            if num_valid_tokens == 0:
                                continue
                            # Check received data
                            if current_x is x:
                                recv_x = recv_x[:num_valid_tokens]
                                recv_x_amin = recv_x[:, :-128].amin(dim=-1)
                                recv_x_amax = recv_x[:, :-128].amax(dim=-1)
                                if enable_dispatch_ll_layered or enable_combine_overlap:
                                    # Mask off the extra information carried in the upper bits
                                    recv_src_info = recv_src_info[:num_valid_tokens] & int_mask
                                else:
                                    recv_src_info = recv_src_info[:num_valid_tokens]

                                # The non-index channels of each received token all hold one value
                                assert torch.equal(recv_x_amin, recv_x_amax)
                                if dispatch_use_quant:
                                    assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007

                                # NOTE(review): exact-value checks run only for 128-channel group
                                # quantization without scale rounding; unquantized and per-token
                                # quantized data are not checked exactly here. Confirm this is intended.
                                if quant_group_size != 0 and not fp8_round_scale:
                                    assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
                                    for j in range(num_ranks):
                                        begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
                                        assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
                                        assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0

                            if dispatch_use_quant:
                                hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
                                hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
                            else:
                                hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])

                        # Check combine correctness
                        for zero_copy in (False, ) if use_logfmt else (False, True, ):
                            if zero_copy:
                                buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
                            out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
                            if enable_combine_overlap:
                                block_m, threshold, num_sms = 64, 10, 3
                                # Total number of signals per local expert
                                total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m)
                                comp_signal = torch.zeros(num_local_experts * total_num_per_expert, dtype=torch.int32, device='cuda')

                                # Set the first ceil(recv_count / block_m) signals of each expert to
                                # `threshold` (presumably marking those GEMM blocks as done; confirm
                                # against the combine kernel)
                                for i in range(num_local_experts):
                                    valid_num = ceil_div(packed_recv_count[i], block_m)
                                    comp_signal[i * total_num_per_expert:i * total_num_per_expert + valid_num] = threshold
                                combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
                                                                                     topk_idx,
                                                                                     topk_weights,
                                                                                     handle,
                                                                                     packed_recv_count=packed_recv_count,
                                                                                     comp_signal=comp_signal,
                                                                                     block_m=block_m,
                                                                                     threshold=threshold,
                                                                                     num_sms=num_sms,
                                                                                     async_finish=not return_recv_hook,
                                                                                     zero_copy=zero_copy,
                                                                                     return_recv_hook=return_recv_hook,
                                                                                     out=out)
                            else:
                                combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
                                                                                     topk_idx,
                                                                                     topk_weights,
                                                                                     handle,
                                                                                     use_logfmt=use_logfmt,
                                                                                     async_finish=not return_recv_hook,
                                                                                     zero_copy=zero_copy,
                                                                                     return_recv_hook=return_recv_hook,
                                                                                     out=out)

                            if return_recv_hook:
                                hook()
                            else:
                                event.current_stream_wait()
                            if do_check:
                                diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
                                assert torch.isnan(combined_x).sum().item() == 0
                                assert diff < (9e-4 if dispatch_use_quant else 1e-5), f'Error: diff={diff}, dispatch_use_quant={dispatch_use_quant}, zero_copy={zero_copy}'
                                hash_value ^= hash_tensor(combined_x)

                        if rank == 0:
                            print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
                                  f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")

    # Prints "deep_ep: all correctness tests completed"
    print("deep_ep 全部正确性测试完成")
    # The benchmark below calls combine without the overlap arguments
    # (comp_signal, block_m, ...), so skip it in these modes
    if enable_dispatch_ll_layered or enable_combine_overlap:
        return hash_value

    # noinspection PyShadowingNames
    def large_gemm_with_hook(hook):
        mat_0 = torch.randn((8192, 8192), dtype=torch.float)
        mat_1 = torch.randn((8192, 8192), dtype=torch.float)
        mat_0 @ mat_1
        hook()

    # Benchmark workload: reuses the last (performance) input and the last
    # simulated GEMM output from the correctness loop above.
    # noinspection PyShadowingNames
    def test_func(return_recv_hook: bool):
        recv_x, recv_count, handle, event, hook = \
            buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                        quant_type=2, quant_group_size=0,
                                        async_finish=False, return_recv_hook=return_recv_hook)
        if return_recv_hook:
            large_gemm_with_hook(hook)
        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
                                                             topk_idx,
                                                             topk_weights,
                                                             handle,
                                                             use_logfmt=use_logfmt,
                                                             return_recv_hook=return_recv_hook)
        if return_recv_hook:
            large_gemm_with_hook(hook)

    # Calculate bandwidth
    # NOTE(review): num_fp8_bytes assumes one 4-byte scale per 128 channels, while
    # test_func dispatches with quant_group_size=0. Confirm the per-token message size.
    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
    num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4
    # Combine sends LogFMT-10 payloads when use_logfmt is set, BF16 otherwise
    num_combine_bytes_per_selection = num_logfmt10_bytes if use_logfmt else num_bf16_bytes
    num_selections = (topk_idx != -1).sum().item()
    num_dispatch_comm_bytes = num_fp8_bytes * num_selections
    num_combine_comm_bytes = num_combine_bytes_per_selection * num_selections

    # Dispatch + combine testing
    avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
    print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
          f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us',
          flush=True)

    # Separate profiling
    for return_recv_hook in (False, True):
        group.barrier()
        dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
                                             kernel_names=('dispatch', 'combine'),
                                             barrier_comm_profiling=True,
                                             suppress_kineto_output=True,
                                             num_kernels_per_period=2 if return_recv_hook else 1)
        if not return_recv_hook:
            print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
                  f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us',
                  flush=True)
        else:
            print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
                  f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us',
                  flush=True)
    return hash_value


# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    """Set up distributed state and a low-latency buffer, then run ``test_main``.

    Runs ``test_main`` once with seed 1. With ``--pressure-test``, it then loops
    over seeds and reruns ``test_main`` 20 times per seed, requiring the same
    hash each time. Finally the buffer and the process group are destroyed.

    Args:
        local_rank: unused; the local rank is taken from ``args.local_rank``.
        num_local_ranks: unused; ``args.num_processes`` is used instead.
        args: parsed command-line arguments (see the ``__main__`` block).
    """
    rank, num_ranks, group = init_dist(args.rank, args.world_size, args.local_rank, args.dist_url)
    num_nodes = args.world_size // args.num_processes

    hostname = socket.gethostname()
    ip = socket.gethostbyname(hostname)
    print(f"rank={rank}, num_ranks={num_ranks}, num_nodes={num_nodes}, ip={ip}")

    num_tokens, hidden = args.num_tokens, args.hidden
    num_topk, num_experts = args.num_topk, args.num_experts

    # Layered dispatch always implies combine overlap
    enable_dispatch_ll_layered = args.enable_dispatch_ll_layered
    enable_combine_overlap = args.enable_combine_overlap or enable_dispatch_ll_layered

    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts,
                                                                   enable_dispatch_ll_layered=enable_dispatch_ll_layered)
    if rank == 0:
        print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
    buffer = deep_ep.Buffer(group,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=True,
                            num_qps_per_rank=num_experts // num_ranks,
                            allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
                            explicitly_destroy=True,
                            allow_mnnvl=args.allow_mnnvl,
                            enable_dispatch_ll_layered=enable_dispatch_ll_layered,
                            enable_combine_overlap=enable_combine_overlap)
    # Prints "deep_ep initialization complete"
    print("deep_ep 初始化完成")

    # Every test_main invocation shares all arguments except the seed
    run_test_main = partial(test_main, num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
                            use_logfmt=args.use_logfmt,
                            enable_dispatch_ll_layered=enable_dispatch_ll_layered,
                            enable_combine_overlap=enable_combine_overlap)
    run_test_main(seed=1)

    # Pressure test: rerunning with the same seed must reproduce the same hash
    do_pressure_test = args.pressure_test
    for seed in range(int(1e9) if do_pressure_test else 0):
        if rank == 0:
            print(f'Testing with seed {seed} ...', flush=True)
        ref_hash = run_test_main(seed=seed)
        for _ in range(20):
            assert run_test_main(seed=seed) == ref_hash, f'Error: seed={seed}'

    # Destroy the buffer runtime and communication group
    buffer.destroy()
    dist.barrier()
    dist.destroy_process_group()


if __name__ == '__main__':
    # TODO: you may modify NUMA binding for less CPU overhead
    parser = argparse.ArgumentParser(description='Test low-latency EP kernels')

    # Defaults come from the Open MPI launcher environment
    group = parser.add_argument_group(title='extra distributed args')
    group.add_argument('--rank', default=int(os.getenv('OMPI_COMM_WORLD_RANK', '0')), type=int,
                       help='global rank of this process (default: $OMPI_COMM_WORLD_RANK)')
    group.add_argument('--world-size', type=int, default=int(os.getenv('OMPI_COMM_WORLD_SIZE', '0')),
                       help='total number of processes (default: $OMPI_COMM_WORLD_SIZE)')
    group.add_argument('--local-rank', type=int, default=int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK', '0')),
                       help='local rank passed from distributed launcher.')
    group.add_argument('--dist-url',
                       help='Which master node url for distributed training.')

    parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
    parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)')
    parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)')
    parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')
    parser.add_argument('--num-experts', type=int, default=288, help='Number of experts (default: 288)')
    parser.add_argument('--allow-mnnvl', action="store_true", help='Allow MNNVL for communication')
    parser.add_argument('--disable-nvlink', action='store_true', help='Whether to disable NVLink for testing')
    parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
    parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
    parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')

    # Options required by the newer "sbo" code path
    parser.add_argument('--enable-dispatch-ll-layered', action='store_true', help='Enable low-latency layered dispatch optimization')
    parser.add_argument("--enable-combine-overlap", action='store_true', help='Enable GEMM-compute/communication overlap in the combine phase')

    args = parser.parse_args()

    # NOTE(review): the test only runs when world_size > num_processes (more than one
    # node at num_processes ranks per node); otherwise the script exits without testing.
    # Confirm this gate is intended.
    if args.world_size > args.num_processes:
        test_loop(args.local_rank, args.num_processes, args)