test_low_latency.py 9.75 KB
Newer Older
Chenggang Zhao's avatar
Chenggang Zhao committed
1
2
3
4
5
6
7
8
9
10
import random
import torch
import torch.distributed as dist
from functools import partial

import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back


def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
              rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, seed: int = 0):
    """Check low-latency dispatch/combine correctness on this rank, then benchmark both.

    Every rank builds a token matrix whose leading ``hidden - 128`` columns hold the constant
    ``rank - rank_offset`` and whose last 128 columns hold the token index. It routes the tokens
    to random top-k experts (with a few slots masked to -1). The received and combined data are
    verified for every combination of:
      * BF16/FP8 dispatch,
      * event-based/receive-hook completion,
      * regular/zero-copy combine.
    Finally, end-to-end and per-kernel bandwidth figures are printed.

    Args:
        num_tokens: number of tokens dispatched by this rank.
        hidden: hidden size of every token.
        num_experts: total number of experts; must be divisible by ``num_ranks``.
        num_topk: number of experts selected per token.
        rank: rank of this process.
        num_ranks: number of ranks taking part.
        group: process group used for the ``all_gather`` and the profiling barriers.
        buffer: DeepEP buffer used for the low-latency dispatch/combine calls.
        seed: base seed; torch and ``random`` are seeded with ``seed + rank``.

    Returns:
        XOR of ``hash_tensor`` over the checked received and combined tensors, so repeated
        runs with the same seed can be compared for determinism.
    """
    torch.manual_seed(seed + rank)
    random.seed(seed + rank)

    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks

    # NOTES: integers greater than 256 exceed the BF16 precision limit, so rank values are
    # shifted by ``rank_offset`` to keep ``rank - rank_offset`` exactly representable
    rank_offset = 128
    assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'

    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
    # The last 128 columns carry the token index, so the source token of each received row can be checked
    x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
    # Randomly mask some positions
    for _ in range(10):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1

    # Check dispatch correctness
    do_check = True  # set to False to skip the per-expert and combine checks
    hash_value, num_times = 0, 0
    for return_recv_hook in (False, True):
        for dispatch_use_fp8 in (False, True):
            num_times += 1
            # Alternate between one and two back-to-back dispatches; only the last result is checked
            for _ in range((num_times % 2) + 1):
                packed_recv_x, packed_recv_count, handle, event, hook = \
                    buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8,
                                                async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
                hook() if return_recv_hook else event.current_stream_wait()
            packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
            simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
                if dispatch_use_fp8 else packed_recv_x.clone()
            all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
            dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
            for i in range(num_local_experts if do_check else 0):
                expert_id = rank * num_local_experts + i
                recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
                recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]

                # Check expert indices; each layout entry packs ``begin << 32 | count`` per source rank
                int_mask = (2 ** 32) - 1
                num_valid_tokens = recv_count.item()
                num_layout_tokens = (recv_layout_range & int_mask).sum().item()
                assert num_valid_tokens == num_layout_tokens, f'{num_valid_tokens} != {num_layout_tokens}'
                num_expected_tokens = (all_topk_idx == expert_id).sum().item()
                assert num_valid_tokens == num_expected_tokens, f'{num_valid_tokens} != {num_expected_tokens}'

                # Check received data
                recv_x = recv_x[:num_valid_tokens]
                recv_x_amin = recv_x[:, :-128].amin(dim=-1)
                recv_src_info = recv_src_info[:num_valid_tokens]
                assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
                assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
                for j in range(num_ranks):
                    begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
                    assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
                    # Rows received from rank ``j`` must hold ``j - rank_offset`` in every non-index column.
                    # Index the column axis explicitly: a chained ``[...][:-128]`` would drop rows instead.
                    assert torch.all(recv_x[begin_idx:begin_idx + count, :-128] == j - rank_offset).item()
                if dispatch_use_fp8:
                    hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
                    hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
                else:
                    hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])

            # Check combine correctness, both through the regular path and the zero-copy buffer
            for zero_copy in (False, True):
                if zero_copy:
                    buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
                out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
                combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                                     async_finish=not return_recv_hook, zero_copy=zero_copy,
                                                                     return_recv_hook=return_recv_hook, out=out)
                hook() if return_recv_hook else event.current_stream_wait()
                if do_check:
                    # Expected result: each token scaled by the sum of its unmasked top-k weights
                    diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
                    assert torch.isnan(combined_x).sum().item() == 0
                    assert diff < 1e-5, f'Error: diff={diff}'
                    hash_value ^= hash_tensor(combined_x)

    # noinspection PyShadowingNames
    def large_gemm_with_hook(hook):
        # Run a large 8192x8192 matmul (on the default device) before invoking the receive hook,
        # presumably so the in-flight transfer has time to progress before it is received
        mat_0 = torch.randn((8192, 8192), dtype=torch.float)
        mat_1 = torch.randn((8192, 8192), dtype=torch.float)
        mat_0 @ mat_1
        hook()

    # noinspection PyShadowingNames
    def test_func(zero_copy: bool, return_recv_hook: bool):
        # One BF16 dispatch + combine round trip, used by the benchmarks below
        recv_x, recv_count, handle, event, hook = \
            buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
                                        async_finish=False, return_recv_hook=return_recv_hook)
        large_gemm_with_hook(hook) if return_recv_hook else None
        if zero_copy:
            buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                             zero_copy=zero_copy, return_recv_hook=return_recv_hook)
        large_gemm_with_hook(hook) if return_recv_hook else None

    # Calculate bandwidth: per selected (token, expert) pair, the FP8 payload carries ``hidden`` bytes,
    # ``hidden / 128`` 4-byte scales and 16 extra bytes; BF16 combine carries ``hidden * 2`` bytes
    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
    num_selections = (topk_idx != -1).sum().item()
    num_dispatch_comm_bytes = num_fp8_bytes * num_selections
    num_combine_comm_bytes = num_bf16_bytes * num_selections

    # Dispatch + combine testing
    avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False))
    print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
          f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)

    # Separate profiling
    for return_recv_hook in (False, True):
        group.barrier()
        dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
                                             kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
                                             suppress_kineto_output=True)
        if not return_recv_hook:
            print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
                  f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us')
        else:
            print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | '
                  f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} us')

    return hash_value


lishen's avatar
lishen committed
145
146
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
    """Per-process entry point: set up distribution and a low-latency buffer, run the tests, then clean up.

    Args:
        local_rank: index of this process among the spawned local processes.
        num_local_ranks: number of local processes, forwarded to ``init_dist``.
    """
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    # The default setting of DeepEP upstream is below:
    num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 256

    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
    if local_rank == 0:
        print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
    # With ``explicitly_destroy=True`` the buffer is released explicitly via ``buffer.destroy()`` at the end
    buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
                            num_qps_per_rank=num_experts // num_ranks, explicitly_destroy=True)
    test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)

    # Set to True to re-run every seed 20 more times and require an identical result hash
    do_pressure_test = False
    for seed in range(int(1e9) if do_pressure_test else 0):
        if local_rank == 0:
            print(f'Testing with seed {seed} ...', flush=True)
        ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed)
        for i in range(20):
            assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}'

    # Destroy the buffer runtime and the communication group
    buffer.destroy()
    dist.barrier()
    dist.destroy_process_group()
Chenggang Zhao's avatar
Chenggang Zhao committed
165
166

if __name__ == '__main__':
    # Script entry: spawn one worker per local rank; each worker runs ``test_loop``
    print("main start...")
    # TODO: you may modify NUMA binding for less CPU overhead
    world_size = 8
    torch.multiprocessing.spawn(test_loop, args=(world_size, ), nprocs=world_size)