"vscode:/vscode.git/clone" did not exist on "f0b661b8fbb71386dec3d7aaa92bc7d318adad2f"
test_utils.py 4.01 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch


def create_per_token_group_quant_test_data(num_tokens, hidden_dim, num_ranks, flags):
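    """Build bf16 activation test data for per-token-group quantization tests.

    Returns ``(x, masked_m)``. When ``flags["masked_layout_mode"]`` is set, ``x``
    mimics a DeepEP ``low_latency_dispatch`` output of shape
    ``(num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden)``
    and ``masked_m`` holds the per-local-expert token counts; otherwise ``x`` is a
    dense ``(num_tokens, hidden)`` tensor and ``masked_m`` is ``None``.
    ``flags["fuse_silu_and_mul"]`` doubles the effective hidden dimension.
    """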
    device = torch.device("cuda")
    dtype = torch.bfloat16

    seed = num_tokens * 10000 + hidden_dim
    gen_cpu = torch.Generator(device="cpu")
    gen_cpu.manual_seed(seed)
    gen_cuda = torch.Generator(device="cuda")
    gen_cuda.manual_seed(seed)

    if flags["fuse_silu_and_mul"]:
        effective_hidden_dim = hidden_dim * 2
    else:
        effective_hidden_dim = hidden_dim
    del hidden_dim

    if (masked_layout_mode := flags["masked_layout_mode"]) is not None:
        num_max_dispatch_tokens_per_rank = 768
        num_global_experts = 288
        num_local_experts, remainder = divmod(num_global_experts, num_ranks)
        assert remainder == 0

        # mimic DeepEP low_latency_dispatch output
        x = torch.randn(
            num_local_experts,
            num_max_dispatch_tokens_per_rank * num_ranks,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )

        if masked_layout_mode == "balanced":
            masked_m = _compute_balanced_split(num_tokens, num_local_experts)
        elif masked_layout_mode == "imbalanced":
            masked_m = _compute_imbalanced_split(
                num_tokens, num_local_experts, gen_cpu=gen_cpu
            )
        elif masked_layout_mode == "extreme":
            masked_m = torch.tensor(
                [num_tokens] + [0] * (num_local_experts - 1), dtype=torch.int
            )
        else:
            raise NotImplementedError
        print(f"{masked_layout_mode=} {masked_m=} {x.shape=}")

        masked_m = masked_m.to(device)

        return x, masked_m
    else:
        x = torch.randn(
            num_tokens,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )
        x[torch.randn(x.shape, device=device, generator=gen_cuda) < 0.001] *= 10
        return x, None


def _compute_balanced_split(total: int, arr_len: int):
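    """Split ``total`` tokens as evenly as possible across ``arr_len`` experts."""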
    base = total // arr_len
    remainder = total % arr_len
    ans = [base + 1 if i < remainder else base for i in range(arr_len)]
    assert sum(ans) == total
    return torch.tensor(ans, dtype=torch.int)


def _compute_imbalanced_split(
    total: int, arr_len: int, gen_cpu, dtype=torch.int
) -> torch.Tensor:
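    """Split ``total`` tokens across ``arr_len`` experts with a skewed random distribution."""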
    # can use `rand ** 2`, `rand ** 3`, etc, to change how imbalanced it is
    noise_raw = torch.rand(arr_len, generator=gen_cpu) ** 3

    noise = noise_raw / noise_raw.sum()
    ans = (noise * total).round().to(dtype)

    diff = total - ans.sum().item()
    while diff != 0:
        idx = torch.randint(0, arr_len, (1,), generator=gen_cpu).item()
        if diff > 0:
            ans[idx] += 1
            diff -= 1
        elif diff < 0 and ans[idx] > 0:
            ans[idx] -= 1
            diff += 1

    assert ans.sum().item() == total
    return ans


def assert_all_close_or_tiny_diff(a: torch.Tensor, b: torch.Tensor):
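    """Assert ``a`` and ``b`` match except for a small fraction of off-by-one
    differences in their quantized (fp8 / int8) bit patterns, with no sign flips
    or larger discrepancies."""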
    assert (a.shape == b.shape) and (
        a.dtype == b.dtype
    ), f"{a.shape=} {b.shape=} {a.dtype=} {b.dtype=}"
    numel = a.numel()

    if a.dtype == torch.float8_e4m3fn:
        a_u8 = a.view(torch.uint8)
        b_u8 = b.view(torch.uint8)
        diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()

        # sign bit of float8_e4m3fn is the top bit of the byte
        count_diff_sign = ((a_u8 >= 128) != (b_u8 >= 128)).sum().item()
        count_tiny_diff = (diff_u8 == 1).sum().item()
        count_large_diff = (diff_u8 >= 2).sum().item()
    elif a.dtype == torch.int8:
        diff = (a.to(torch.int16) - b.to(torch.int16)).abs()
        count_diff_sign = ((a >= 0) & (b < 0)).sum().item()
        count_tiny_diff = (diff == 1).sum().item()
        count_large_diff = (diff >= 2).sum().item()
    else:
        raise NotImplementedError

    assert (
        (count_diff_sign == 0)
        and (count_large_diff == 0)
        and (
            (count_tiny_diff / numel < 0.005)
            or ((count_tiny_diff / numel < 0.04) and (numel <= 4096))
        )
    ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=} {a=} {b=}"