test_moe_swiglu_dynamic_quant.py 8.79 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# SPDX-License-Identifier: MIT

import torch
import torch.nn.functional as F
import aiter
from aiter.test_common import checkAllclose, perftest
from aiter.ops.quant import moe_swiglu_dynamic_quant_wrapper
from aiter import dtypes
import argparse


# ---------------------------------------------------------------------------
# Pure-PyTorch reference: MoE SwiGLU + dynamic per-token int8 quantization
# ---------------------------------------------------------------------------
def moe_swiglu_dynamic_quant_ref(scatter_tokens, smooth, experts_tokens_count,
                                  experts_tokens_start, beta=1.0):
    """
    Reference implementation of MoE SwiGLU + per-expert smoothing + dynamic
    per-token symmetric int8 quantization.

    For each token t: act = silu(gate[t]) * up[t] (computed in float32). If t
    lies in some expert e's range [start[e], start[e] + count[e]), act is
    multiplied element-wise by smooth[e]. Each row is then quantized with
    scale = max|act| / 127.

    scatter_tokens:       [num_tokens, 2*d]  — first half gate, second half up (fp16/bf16)
    smooth:               [num_experts, d]   — per-expert smooth scales (cast to float32)
    experts_tokens_count: [num_experts]       — token count per expert
    experts_tokens_start: [num_experts]       — token start index per expert
    beta:                 float, unused (kept for signature compatibility)

    Tokens not covered by any expert range are quantized without smoothing.
    If expert ranges overlap, the expert with the higher index wins.

    Returns:
        output: [num_tokens, d] int8
        scales: [num_tokens] float32 (0 for all-zero rows, whose output is 0)
    """
    device = scatter_tokens.device
    num_tokens, d2 = scatter_tokens.shape
    d = d2 // 2
    num_experts = smooth.shape[0]

    gate = scatter_tokens[:, :d].float()
    up = scatter_tokens[:, d:].float()
    act = F.silu(gate) * up  # [num_tokens, d], float32

    # Map each token to its owning expert (-1 = not owned). This loops over
    # experts, not tokens, so the number of host syncs stays small.
    expert_ids = torch.full((num_tokens,), -1, dtype=torch.long, device=device)
    for e in range(num_experts):
        start = int(experts_tokens_start[e])
        count = int(experts_tokens_count[e])
        if count > 0:
            expert_ids[start:start + count] = e

    # Smooth all owned tokens at once: one gather + multiply instead of a
    # per-token Python loop with a .item() sync on every row.
    owned = expert_ids >= 0
    act[owned] = act[owned] * smooth.float()[expert_ids[owned]]

    row_max = act.abs().amax(dim=-1)
    scales = row_max / 127.0

    # All-zero rows get inverse scale 0 (not inf), so their output stays 0.
    inv_scale = torch.where(row_max > 0, 127.0 / row_max, torch.zeros_like(row_max))
    q = act * inv_scale.unsqueeze(-1)
    q = q.round().clamp(-128, 127).to(torch.int8)

    return q, scales


@perftest()
def run_ref(scatter_tokens, smooth, experts_tokens_count, experts_tokens_start):
    """perftest-wrapped call of the PyTorch reference (default beta)."""
    return moe_swiglu_dynamic_quant_ref(
        scatter_tokens,
        smooth,
        experts_tokens_count=experts_tokens_count,
        experts_tokens_start=experts_tokens_start,
    )


@perftest()
def run_aiter(scatter_tokens, smooth, experts_tokens_count, experts_tokens_start):
    """perftest-wrapped call of the aiter kernel wrapper."""
    operands = (scatter_tokens, smooth, experts_tokens_count, experts_tokens_start)
    return moe_swiglu_dynamic_quant_wrapper(*operands)


def make_moe_metadata(num_tokens, num_experts, device):
    """Generate experts_tokens_count and experts_tokens_start for testing.

    Tokens are split as evenly as possible into contiguous per-expert ranges:
    the first ``num_tokens % num_experts`` experts receive one extra token.
    ``experts_tokens_start`` is the exclusive prefix sum of the counts.

    Returns:
        (experts_tokens_count, experts_tokens_start): int32 tensors of shape
        [num_experts] on ``device``.

    Raises:
        ValueError: if ``num_experts`` is not positive.
    """
    if num_experts <= 0:
        raise ValueError(f"num_experts must be positive, got {num_experts}")

    base, remainder = divmod(num_tokens, num_experts)
    experts_tokens_count = torch.full((num_experts,), base, dtype=torch.int32, device=device)
    experts_tokens_count[:remainder] += 1

    # Exclusive prefix sum. dtype is pinned because cumsum would otherwise
    # promote the int32 counts to int64.
    experts_tokens_start = (
        torch.cumsum(experts_tokens_count, dim=0, dtype=torch.int32) - experts_tokens_count
    )

    return experts_tokens_count, experts_tokens_start


def test_correctness(dtype, num_tokens, d, num_experts):
    """Run the reference and the aiter kernel on one random MoE problem.

    Prints the int8 exact-match ratio plus cosine similarities of the
    quantized outputs and the scales, then compares both pairs with
    ``checkAllclose``.
    """
    device = "cuda"
    scatter_tokens = torch.randn(num_tokens, 2 * d, dtype=dtype, device=device)
    smooth = torch.randn(num_experts, d, dtype=torch.float32, device=device)
    counts, starts = make_moe_metadata(num_tokens, num_experts, device)
    kernel_args = (scatter_tokens, smooth, counts, starts)

    (ref_out, ref_scales), avg_ref = run_ref(*kernel_args)
    (aiter_out, aiter_scales), avg_aiter = run_aiter(*kernel_args)

    def _flat_cos(a, b):
        return F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()

    ref_q = ref_out.float()
    aiter_q = aiter_out.float()
    out_match = (ref_out == aiter_out).float().mean().item()
    cos_out = _flat_cos(ref_q, aiter_q)
    cos_scale = _flat_cos(ref_scales, aiter_scales)

    shape = f"tokens: {num_tokens:<6}, d: {d:<6}, experts: {num_experts:<4}, "
    msg = ("[perf] " + shape
           + f"dtype: {dtype}, ref: {avg_ref:<8.2f} us, aiter: {avg_aiter:<8.2f} us")
    print("[acc]  " + shape
          + f"dtype: {str(dtype):<8}, out_match: {out_match:.4f}, "
          + f"cos_out: {cos_out:.8f}, cos_scale: {cos_scale:.8f}")

    checkAllclose(ref_q, aiter_q, atol=0.5, rtol=0.0,
                  msg=f"{msg} [output]")
    checkAllclose(ref_scales, aiter_scales, atol=2e-4, rtol=2e-4,
                  msg=f"{msg} [scales]")


# Default smoke cases as (num_tokens, d, num_experts): a single-token /
# single-expert edge case plus two small multi-token shapes.
# 3 cases x 2 dtypes = 6 runs when no --dtype is given.
SMOKE_CASES = [
    (1, 128, 1),
    (32, 256, 1),
    (256, 512, 4),
]


def run_smoke_tests(dtypes_to_run):
    """Run test_correctness on every SMOKE_CASES shape for each requested dtype."""
    rule = "=" * 60
    for line in (rule, "SMOKE CORRECTNESS TESTS", rule):
        print(line)
    for dtype in dtypes_to_run:
        for case in SMOKE_CASES:
            test_correctness(dtype, *case)


def test_empty_experts():
    """Compare reference and kernel when some experts own zero tokens.

    Experts 0 and 1 split all 64 tokens; experts 2 and 3 have count 0 and a
    start index equal to the end of the token range.
    """
    device, dtype = "cuda", dtypes.fp16
    num_tokens, d, num_experts = 64, 256, 4

    scatter_tokens = torch.randn(num_tokens, 2 * d, dtype=dtype, device=device)
    smooth = torch.randn(num_experts, d, dtype=torch.float32, device=device)
    meta_opts = dict(dtype=torch.int32, device=device)
    counts = torch.tensor([32, 32, 0, 0], **meta_opts)
    starts = torch.tensor([0, 32, 64, 64], **meta_opts)
    inputs = (scatter_tokens, smooth, counts, starts)

    ref_out, ref_scales = moe_swiglu_dynamic_quant_ref(*inputs)
    aiter_out, aiter_scales = moe_swiglu_dynamic_quant_wrapper(*inputs)

    match_ratio = (ref_out == aiter_out).float().mean().item()
    print(f"[empty-experts] out_match: {match_ratio:.4f}")
    checkAllclose(ref_out.float(), aiter_out.float(), atol=0.5, rtol=0.0,
                  msg="[empty-experts] [output]")
    checkAllclose(ref_scales, aiter_scales, atol=2e-4, rtol=2e-4,
                  msg="[empty-experts] [scales]")


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
l_dtype = ["fp16", "bf16"]
l_tokens = [32, 256, 1024]
l_d = [128, 512, 1024]
l_experts = [2, 4, 8]


def _print_section(title):
    """Print a blank line followed by a '='-framed section title."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _run_grid(dtypes_to_run, tokens_list, d_list, experts_list, skip_smoke=False):
    """Run test_correctness over every dtype x tokens x d x experts combination.

    Combinations with fewer tokens than experts are skipped. When
    ``skip_smoke`` is set, shapes already listed in SMOKE_CASES are skipped too.

    Returns:
        Number of test_correctness calls actually made.
    """
    runs = 0
    for dtype in dtypes_to_run:
        for num_tokens in tokens_list:
            for d in d_list:
                for num_experts in experts_list:
                    if num_tokens < num_experts:
                        continue
                    if skip_smoke and (num_tokens, d, num_experts) in SMOKE_CASES:
                        continue
                    test_correctness(dtype, num_tokens, d, num_experts)
                    runs += 1
    return runs


parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="Test moe_swiglu_dynamic_quant (default: small smoke suite)",
)
parser.add_argument("-d", "--dtype", type=str, choices=l_dtype, default=None,
                    help="Data type, e.g. -d fp16")
parser.add_argument("--tokens", type=int, default=None, help="Number of tokens")
parser.add_argument("--dim", type=int, default=None, help="Hidden dimension d")
parser.add_argument("--experts", type=int, default=None, help="Number of experts")
parser.add_argument("--all", action="store_true",
                    help="Run extended grid (more shapes) plus --empty")
parser.add_argument("--empty", action="store_true", help="Run empty-experts edge case")

args = parser.parse_args()

if args.dtype is None:
    dtypes_to_run = [dtypes.d_dtypes[key] for key in l_dtype]
else:
    dtypes_to_run = [dtypes.d_dtypes[args.dtype]]
if args.tokens is not None:
    l_tokens = [args.tokens]
if args.dim is not None:
    l_d = [args.dim]
if args.experts is not None:
    l_experts = [args.experts]

custom_shape = args.tokens is not None or args.dim is not None or args.experts is not None

if custom_shape:
    if _run_grid(dtypes_to_run, l_tokens, l_d, l_experts) == 0:
        # Every requested combination had tokens < experts; say so rather
        # than silently reporting success on zero runs.
        print("[skip] no requested shape has tokens >= experts; no correctness case ran")
else:
    run_smoke_tests(dtypes_to_run)

# --empty / --all are honored in both default and custom-shape modes.
if args.empty or args.all:
    _print_section("EMPTY EXPERTS EDGE CASE")
    test_empty_experts()

# The extended grid only applies to the default shape lists; in custom-shape
# mode the requested grid already ran above.
if args.all and not custom_shape:
    _print_section("EXTENDED CORRECTNESS GRID")
    _run_grid(dtypes_to_run, l_tokens, l_d, l_experts, skip_smoke=True)

print("\n✅ All moe_swiglu_dynamic_quant tests passed.")