test_low_latency.py 18.2 KB
Newer Older
Chenggang Zhao's avatar
Chenggang Zhao committed
1
import argparse
Chenggang Zhao's avatar
Chenggang Zhao committed
2
import random
3
4
import time
import os
Chenggang Zhao's avatar
Chenggang Zhao committed
5
6
import torch
import torch.distributed as dist
7
import numpy as np
Chenggang Zhao's avatar
Chenggang Zhao committed
8
from functools import partial
9
from typing import Optional
Chenggang Zhao's avatar
Chenggang Zhao committed
10
11
12
13
14
15

import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back


def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
              rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer,
              use_logfmt: bool = False, seed: int = 0, enable_diagnose: bool = False):
    """
    Check low-latency dispatch/combine on this rank, then benchmark them.

    Arguments:
        num_tokens: number of tokens generated on this rank.
        hidden: hidden size of each token.
        num_experts: total number of experts, must be divisible by `num_ranks`.
        num_topk: number of experts selected per token.
        rank: index of this rank, also used to tag the test payload.
        num_ranks: number of ranks in `group`.
        group: the process group used for the reference all-gather, barriers and stats gathering.
        buffer: the buffer whose `low_latency_dispatch` / `low_latency_combine` are tested.
        use_logfmt: forwarded to `low_latency_combine`; the zero-copy combine variant is skipped when set.
        seed: `seed + rank` seeds both `torch` and `random`.
        enable_diagnose: whether to additionally run the slow-rank diagnose test.

    Returns:
        The XOR of `hash_tensor` over all received dispatch payloads and combined outputs,
        so that two runs with the same seed can be compared.
    """
    torch.manual_seed(seed + rank)
    random.seed(seed + rank)

    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks

    # NOTES: the integers greater than 256 exceed the BF16 precision limit
    rank_offset = 128
    assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'

    # Structured payload: every element is `rank - rank_offset`, except the last 128 columns
    # which hold the token index
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
    x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
    x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()

    # Randomly mask some positions
    for i in range(10):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1

    # Check dispatch correctness
    do_check = True
    hash_value, num_times = 0, 0
    for current_x in (x, x_pure_rand):
        for return_recv_hook in (False, True):
            for dispatch_use_fp8 in (False, True):
                for round_scale in (False, True) if dispatch_use_fp8 else (False, ):
                    for use_ue8m0 in (False, True) if round_scale else (False, ):
                        num_times += 1
                        # Every other configuration issues the dispatch twice; only the last result is checked
                        for i in range((num_times % 2) + 1):
                            cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
                            packed_recv_x, packed_recv_count, handle, event, hook = \
                                buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                                            use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
                                                            cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                                            async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
                            hook() if return_recv_hook else event.current_stream_wait()
                        packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
                        simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
                            if dispatch_use_fp8 else packed_recv_x.clone()
                        # Gather every rank's selections as the reference for the per-expert counts
                        all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
                        dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
                        for i in range(num_local_experts if do_check else 0):
                            expert_id = rank * num_local_experts + i
                            recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
                            recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]

                            # Check expert indices
                            # NOTES: each layout range entry packs the begin index in the upper 32 bits
                            # and the token count in the lower 32 bits
                            int_mask = (2 ** 32) - 1
                            num_valid_tokens = recv_count.item()
                            assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
                            assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {(recv_layout_range & int_mask).sum().item()}'
                            assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'

                            # Check received data
                            if current_x is not x_pure_rand:
                                recv_x = recv_x[:num_valid_tokens]
                                recv_x_amin = recv_x[:, :-128].amin(dim=-1)
                                recv_src_info = recv_src_info[:num_valid_tokens]
                                # All columns but the last 128 must hold one constant value per token
                                assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
                                if round_scale:
                                    assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
                                else:
                                    assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
                                for j in range(num_ranks):
                                    begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
                                    # Exact value checks are only done without scale rounding
                                    if not round_scale:
                                        assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
                                        # Tokens in the range of source rank `j` must carry the tag of rank `j`
                                        assert (recv_x[begin_idx:begin_idx + count, :-128] == j - rank_offset).all().item()
                            if dispatch_use_fp8:
                                hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
                                hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
                            else:
                                hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])

                        # Check combine correctness
                        for zero_copy in (False, ) if use_logfmt else (False, True):
                            if zero_copy:
                                buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
                            out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
                            combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                                                 use_logfmt=use_logfmt,
                                                                                 async_finish=not return_recv_hook, zero_copy=zero_copy,
                                                                                 return_recv_hook=return_recv_hook, out=out)
                            hook() if return_recv_hook else event.current_stream_wait()
                            if do_check:
                                # Reference: the input scaled by the sum of the weights of the non-masked selections
                                diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
                                assert torch.isnan(combined_x).sum().item() == 0
                                assert diff < (7e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {zero_copy=}'
                                hash_value ^= hash_tensor(combined_x)

    # NOTES: the helpers below reuse `cumulative_local_expert_recv_stats`, `simulated_gemm_x` and `handle`
    # left over from the last iteration of the correctness loops above

    # noinspection PyShadowingNames
    def large_gemm_with_hook(hook):
        # Run a large matrix multiplication before calling the receive hook
        mat_0 = torch.randn((8192, 8192), dtype=torch.float)
        mat_1 = torch.randn((8192, 8192), dtype=torch.float)
        mat_0 @ mat_1
        hook()

    # noinspection PyShadowingNames
    def test_func(return_recv_hook: bool):
        # One FP8 dispatch followed by one combine, the unit of work being benchmarked
        recv_x, recv_count, handle, event, hook = \
            buffer.low_latency_dispatch(x_pure_rand, topk_idx, num_tokens, num_experts,
                                        cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                        use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
        large_gemm_with_hook(hook) if return_recv_hook else None
        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                             use_logfmt=use_logfmt, return_recv_hook=return_recv_hook)
        large_gemm_with_hook(hook) if return_recv_hook else None

    # noinspection PyShadowingNames
    def test_diagnose(test_dispatch_slow: bool, slow_rank: int,
                      dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
                      combine_wait_recv_cost_stats: Optional[torch.Tensor] = None):
        # Delay `slow_rank` on the host right before it enters the kernel under test
        if test_dispatch_slow:
            if rank == slow_rank:
                time.sleep(0.001)
            buffer.low_latency_dispatch(x_pure_rand, topk_idx, num_tokens, num_experts,
                                        cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                        dispatch_wait_recv_cost_stats=dispatch_wait_recv_cost_stats,
                                        use_fp8=True, async_finish=False)
        else:
            if rank == slow_rank:
                time.sleep(0.001)
            buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                       use_logfmt=use_logfmt, return_recv_hook=False,
                                       combine_wait_recv_cost_stats=combine_wait_recv_cost_stats)

    # Calculate bandwidth
    # NOTES: count all non-masked selections with a single reduction instead of one host sync per token
    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
    num_selections = (topk_idx != -1).sum().item()
    num_dispatch_comm_bytes = num_fp8_bytes * num_selections
    num_combine_comm_bytes = num_bf16_bytes * num_selections

    # Dispatch + combine testing
    avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
    print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
          f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)

    # Separate profiling
    for return_recv_hook in (False, True):
        group.barrier()
        dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
                                             kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
                                             suppress_kineto_output=True, num_kernels_per_period=2 if return_recv_hook else 1)
        if not return_recv_hook:
            print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
                  f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True)
        else:
            # With the receive hook, each kernel is reported as a (send, recv) pair
            print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
                  f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True)

    # Diagnose test
    if enable_diagnose:
        def diagnose_matrix(
            mat, thres_col=3.0, thres_row=3.0, thres_point=5.0,
            suppress_points_in_strong_rowscols=True
        ):
            """
            mat: 2D numpy array, mat[i, j] = the waiting time of src i waiting for dst j to receive the token
            Returns abnormal columns/rows/points.
            suppress_points_in_strong_rowscols: whether to remove points located in already detected abnormal rows or columns
            """
            # NOTES: the scores below are ratios to the mean, the z-score variants are kept for reference only

            # 1. Check for abnormal columns
            col_means = mat.mean(axis=0)
            # z_col = (col_means - col_means.mean()) / (col_means.std() + 1e-8)
            z_col = col_means / (col_means.mean() + 1e-8)
            abnormal_cols = np.where(z_col > thres_col)[0].tolist()

            # 2. Check for abnormal rows
            row_means = mat.mean(axis=1)
            # z_row = (row_means - row_means.mean()) / (row_means.std() + 1e-8)
            z_row = row_means / (row_means.mean() + 1e-8)
            abnormal_rows = np.where(z_row > thres_row)[0].tolist()

            # 3. Check for abnormal single points
            # z_all = (mat - mat.mean()) / (mat.std() + 1e-8)
            z_all = mat / (mat.mean() + 1e-8)
            # Get all positions whose ratio to the global mean exceeds the threshold
            abnormal_points = [
                (i, j, mat[i, j], z_all[i, j])
                for i in range(mat.shape[0])
                for j in range(mat.shape[1])
                if z_all[i, j] > thres_point
            ]
            # Optionally remove points that are in already detected abnormal rows
            # or columns
            if suppress_points_in_strong_rowscols:
                abnormal_points = [
                    (i, j, v, z) for (i, j, v, z) in abnormal_points
                    if i not in abnormal_rows and j not in abnormal_cols
                ]
            # 4. Return for automatic processing
            return {
                'abnormal_cols': abnormal_cols,
                'abnormal_rows': abnormal_rows,
                'abnormal_points': abnormal_points
            }

        dispatch_wait_recv_cost_stats = torch.zeros((num_ranks, ), dtype=torch.int64, device='cuda')
        combine_wait_recv_cost_stats = torch.zeros((num_ranks, ), dtype=torch.int64, device='cuda')
        # Rank 0 is slowed down for the dispatch test, rank 1 for the combine test
        slow_rank = [0, 1]
        for i, test_dispatch_slow in enumerate([True, False]):
            bench(
                partial(
                    test_diagnose,
                    test_dispatch_slow=test_dispatch_slow,
                    slow_rank=slow_rank[i],
                    dispatch_wait_recv_cost_stats=dispatch_wait_recv_cost_stats,
                    combine_wait_recv_cost_stats=combine_wait_recv_cost_stats))
        stats_list = [dispatch_wait_recv_cost_stats, combine_wait_recv_cost_stats]
        stats_tensor = torch.stack(stats_list, dim=0)   # (N, num_ranks)
        # gather all ranks dispatch and combine diagnose stats to rank 0
        gather_tensor = [torch.zeros_like(stats_tensor) for _ in range(group.size())] if rank == 0 else None
        dist.gather(stats_tensor, gather_list=gather_tensor, group=group, dst=0)
        if rank == 0:
            # (group size, N, num_ranks): one (N, num_ranks) stats slab per gathered rank
            stats_arr = torch.stack([it.cpu() for it in gather_tensor], dim=0).numpy()
            for i, name in enumerate(["Dispatch", "Combine"]):
                res = diagnose_matrix(stats_arr[:, i, :])
                assert slow_rank[i] in res[
                    'abnormal_cols'], f"[Diagnose] test failure, slow_rank {slow_rank[i]} not found in abnormal_cols {res['abnormal_cols']}"
                print(
                    f'[Diagnose] test successful!!! [{name}] slow_rank: {slow_rank[i]} diagnose info: {res}')
    return hash_value


Chenggang Zhao's avatar
Chenggang Zhao committed
250
251
# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    """
    Per-process entry point: set up the distributed group and the low-latency buffer,
    run `test_main` once (optionally followed by the endless pressure test), then tear everything down.

    Arguments:
        local_rank: index of this process, forwarded to `init_dist`.
        num_local_ranks: number of spawned processes, forwarded to `init_dist`.
        args: the parsed command line options.
    """
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)

    # Allocate the low-latency buffer with the size hinted for this configuration
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        args.num_tokens, args.hidden, num_ranks, args.num_experts)
    if local_rank == 0:
        print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
    buffer = deep_ep.Buffer(
        group,
        num_rdma_bytes=num_rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=args.num_experts // num_ranks,
        allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
        explicitly_destroy=True,
    )

    # Every run shares the same problem sizes, group and buffer; only the seed and diagnose flag vary
    run_test = partial(
        test_main,
        args.num_tokens, args.hidden, args.num_experts, args.num_topk,
        rank, num_ranks, group, buffer,
        use_logfmt=args.use_logfmt,
    )
    run_test(seed=1, enable_diagnose=args.enable_diagnose)

    if args.pressure_test:
        for seed in range(int(1e9)):
            if local_rank == 0:
                print(f'Testing with seed {seed} ...', flush=True)
            # The first run of a seed is the reference, 20 reruns must reproduce its hash
            ref_hash = run_test(seed=seed)
            for _ in range(20):
                assert run_test(seed=seed) == ref_hash, f'Error: seed={seed}'

    # Destroy the buffer runtime and communication group
    buffer.destroy()
    dist.barrier()
    dist.destroy_process_group()

Chenggang Zhao's avatar
Chenggang Zhao committed
280
281
282

if __name__ == '__main__':
    # TODO: you may modify NUMA binding for less CPU overhead
    # TODO: buggy with `num_tokens=512`
    parser = argparse.ArgumentParser(description='Test low-latency EP kernels')

    # Integer options as (flag, default, help), registered in display order
    for flag, default, help_text in (
        ('--num-processes', 8, 'Number of processes to spawn (default: 8)'),
        ('--num-tokens', 128, 'Number of tokens (default: 128)'),
        ('--hidden', 7168, 'Hidden dimension size (default: 7168)'),
        ('--num-topk', 8, 'Number of top-k experts (default: 8)'),
        ('--num-experts', 288, 'Number of experts (default: 288)'),
    ):
        parser.add_argument(flag, type=int, default=default, help=help_text)

    # Boolean switches as (flag, help), all off by default
    for flag, help_text in (
        ('--disable-nvlink', 'Whether to disable NVLink for testing'),
        ('--use-logfmt', 'Whether to test LogFMT combine'),
        ('--pressure-test', 'Whether to do pressure test'),
        ('--enable-diagnose', 'Whether to enable diagnose for testing'),
    ):
        parser.add_argument(flag, action='store_true', help=help_text)

    args = parser.parse_args()

    # One `test_loop` process per requested rank
    num_processes = args.num_processes
    torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)