# SPDX-License-Identifier: MIT
import torch
import torch.nn.functional as F
import aiter
from aiter.test_common import checkAllclose, perftest
from aiter.ops.quant import moe_swiglu_dynamic_quant_wrapper
from aiter import dtypes
import argparse


# ---------------------------------------------------------------------------
# Pure-PyTorch reference: MoE SwiGLU + dynamic per-token int8 quantization
# ---------------------------------------------------------------------------
def moe_swiglu_dynamic_quant_ref(scatter_tokens, smooth, experts_tokens_count,
                                 experts_tokens_start, beta=1.0):
    """
    Reference implementation.

    scatter_tokens: [num_tokens, 2*d] — gate & up projections (fp16/bf16)
    smooth: [num_experts, d] — per-expert smooth scales (float32)
    experts_tokens_count: [num_experts] — token count per expert
    experts_tokens_start: [num_experts] — token start index per expert
    beta: float, unused

    Returns:
        output: [num_tokens, d] int8
        scales: [num_tokens] float32
    """
    device = scatter_tokens.device
    num_tokens, d2 = scatter_tokens.shape
    d = d2 // 2
    num_experts = smooth.shape[0]

    # SwiGLU: silu(gate) * up, computed in float32.
    gate = scatter_tokens[:, :d].float()
    up = scatter_tokens[:, d:].float()
    act = F.silu(gate) * up  # [num_tokens, d]

    # Map every token row to its expert; -1 marks rows no expert claims.
    # The metadata is pulled to the host once instead of one .item() per read.
    expert_ids = torch.full((num_tokens,), -1, dtype=torch.long, device=device)
    starts = experts_tokens_start.tolist()
    counts = experts_tokens_count.tolist()
    for e in range(num_experts):
        start, count = starts[e], counts[e]
        if count > 0:
            expert_ids[start:start + count] = e

    # Apply the owning expert's smooth row to every assigned token in a single
    # gather + multiply (replaces a per-token Python loop with a sync per row).
    smooth_f = smooth.float()
    assigned = expert_ids >= 0
    act[assigned] = act[assigned] * smooth_f[expert_ids[assigned]]

    # Dynamic per-token quantization: scale each row so its max |x| maps to 127.
    row_max = act.abs().max(dim=-1).values
    scales = row_max / 127.0
    # All-zero rows get inv_scale 0 so they quantize to 0 rather than NaN.
    inv_scale = torch.where(row_max > 0, 127.0 / row_max,
                            torch.zeros_like(row_max))
    q = act * inv_scale.unsqueeze(-1)
    q = q.round().clamp(-128, 127).to(torch.int8)
    return q, scales


@perftest()
def run_ref(scatter_tokens, smooth, experts_tokens_count, experts_tokens_start):
    """Run the PyTorch reference under the perftest harness."""
    return moe_swiglu_dynamic_quant_ref(scatter_tokens, smooth,
                                        experts_tokens_count,
                                        experts_tokens_start)


@perftest()
def run_aiter(scatter_tokens, smooth, experts_tokens_count, experts_tokens_start):
    """Run the aiter kernel wrapper under the perftest harness."""
    return moe_swiglu_dynamic_quant_wrapper(scatter_tokens, smooth,
                                            experts_tokens_count,
                                            experts_tokens_start)


def make_moe_metadata(num_tokens, num_experts, device):
    """Generate experts_tokens_count and experts_tokens_start for testing.

    Tokens are split as evenly as possible: each expert gets
    num_tokens // num_experts tokens and the first num_tokens % num_experts
    experts get one extra. Starts are the exclusive prefix sum of the counts.
    Both returned tensors are int32 on ``device``.
    """
    base, remainder = divmod(num_tokens, num_experts)
    experts_tokens_count = torch.full((num_experts,), base,
                                      dtype=torch.int32, device=device)
    experts_tokens_count[:remainder] += 1
    # Exclusive prefix sum; dtype is pinned because cumsum would otherwise
    # promote int32 input to int64.
    experts_tokens_start = (
        torch.cumsum(experts_tokens_count, dim=0, dtype=torch.int32)
        - experts_tokens_count
    )
    return experts_tokens_count, experts_tokens_start


def test_correctness(dtype, num_tokens, d, num_experts):
    """Compare aiter output against PyTorch reference."""
    device = "cuda"
    scatter_tokens = torch.randn(num_tokens, 2 * d, dtype=dtype, device=device)
    smooth = torch.randn(num_experts, d, dtype=torch.float32, device=device)
    experts_tokens_count, experts_tokens_start = make_moe_metadata(
        num_tokens, num_experts, device)

    # Each perftest-wrapped runner is unpacked as (result, average time).
    (ref_out, ref_scales), avg_ref = run_ref(
        scatter_tokens, smooth, experts_tokens_count, experts_tokens_start)
    (aiter_out, aiter_scales), avg_aiter = run_aiter(
        scatter_tokens, smooth, experts_tokens_count, experts_tokens_start)

    # Summary metrics: exact int8 match ratio plus cosine similarities.
    out_match = (ref_out == aiter_out).float().mean().item()
    cos_out = F.cosine_similarity(
        ref_out.float().flatten(), aiter_out.float().flatten(), dim=0
    ).item()
    cos_scale = F.cosine_similarity(
        ref_scales.flatten(), aiter_scales.flatten(), dim=0
    ).item()

    msg = (f"[perf] tokens: {num_tokens:<6}, d: {d:<6}, experts: {num_experts:<4}, "
           f"dtype: {dtype}, ref: {avg_ref:<8.2f} us, aiter: {avg_aiter:<8.2f} us")
    print(
        f"[acc] tokens: {num_tokens:<6}, d: {d:<6}, experts: {num_experts:<4}, "
        f"dtype: {str(dtype):<8}, out_match: {out_match:.4f}, "
        f"cos_out: {cos_out:.8f}, cos_scale: {cos_scale:.8f}"
    )
    # atol=0.5 on the int8 output tolerates no integer difference (< 1 LSB).
    checkAllclose(ref_out.float(), aiter_out.float(), atol=0.5, rtol=0.0,
                  msg=msg + " [output]")
    checkAllclose(ref_scales, aiter_scales, atol=2e-4, rtol=2e-4,
                  msg=msg + " [scales]")


# Default smoke cases as (num_tokens, d, num_experts): 3 shapes, each run once
# per selected dtype (6 runs with the default fp16 + bf16).
SMOKE_CASES = [
    (1, 128, 1),
    (32, 256, 1),
    (256, 512, 4),
]


def run_smoke_tests(dtypes_to_run):
    """Run test_correctness on every SMOKE_CASES shape for each dtype."""
    print("=" * 60)
    print("SMOKE CORRECTNESS TESTS")
    print("=" * 60)
    for dtype in dtypes_to_run:
        for num_tokens, d, num_experts in SMOKE_CASES:
            test_correctness(dtype, num_tokens, d, num_experts)


def test_empty_experts():
    """Some experts have 0 tokens."""
    device = "cuda"
    dtype = dtypes.fp16
    num_tokens, d, num_experts = 64, 256, 4
    scatter_tokens = torch.randn(num_tokens, 2 * d, dtype=dtype, device=device)
    smooth = torch.randn(num_experts, d, dtype=torch.float32, device=device)
    # Experts 2 and 3 own no tokens; their start index sits at the end (64).
    experts_tokens_count = torch.tensor([32, 32, 0, 0], dtype=torch.int32,
                                        device=device)
    experts_tokens_start = torch.tensor([0, 32, 64, 64], dtype=torch.int32,
                                        device=device)

    ref_out, ref_scales = moe_swiglu_dynamic_quant_ref(
        scatter_tokens, smooth, experts_tokens_count, experts_tokens_start)
    aiter_out, aiter_scales = moe_swiglu_dynamic_quant_wrapper(
        scatter_tokens, smooth, experts_tokens_count, experts_tokens_start)

    out_match = (ref_out == aiter_out).float().mean().item()
    print(f"[empty-experts] out_match: {out_match:.4f}")
    checkAllclose(ref_out.float(), aiter_out.float(), atol=0.5, rtol=0.0,
                  msg="[empty-experts] [output]")
    checkAllclose(ref_scales, aiter_scales, atol=2e-4, rtol=2e-4,
                  msg="[empty-experts] [scales]")


def _run_grid(dtypes_to_run, tokens_list, dim_list, experts_list, skip=()):
    """Run test_correctness over the dtype x tokens x dim x experts product.

    Shapes with fewer tokens than experts are skipped, as are any
    (num_tokens, d, num_experts) tuples listed in ``skip``.
    """
    for dtype in dtypes_to_run:
        for num_tokens in tokens_list:
            for d in dim_list:
                for num_experts in experts_list:
                    if num_tokens < num_experts:
                        continue
                    if (num_tokens, d, num_experts) in skip:
                        continue
                    test_correctness(dtype, num_tokens, d, num_experts)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
l_dtype = ["fp16", "bf16"]
l_tokens = [32, 256, 1024]
l_d = [128, 512, 1024]
l_experts = [2, 4, 8]

parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="Test moe_swiglu_dynamic_quant (default: small smoke suite)",
)
parser.add_argument("-d", "--dtype", type=str, choices=l_dtype, default=None,
                    help="Data type, e.g. -d fp16")
parser.add_argument("--tokens", type=int, default=None, help="Number of tokens")
parser.add_argument("--dim", type=int, default=None, help="Hidden dimension d")
parser.add_argument("--experts", type=int, default=None,
                    help="Number of experts")
parser.add_argument("--all", action="store_true",
                    help="Run extended grid (more shapes) plus --empty")
parser.add_argument("--empty", action="store_true",
                    help="Run empty-experts edge case")
args = parser.parse_args()

if args.dtype is None:
    dtypes_to_run = [dtypes.d_dtypes[key] for key in l_dtype]
else:
    dtypes_to_run = [dtypes.d_dtypes[args.dtype]]

# Any explicit shape flag pins that axis of the grid to the given value.
if args.tokens is not None:
    l_tokens = [args.tokens]
if args.dim is not None:
    l_d = [args.dim]
if args.experts is not None:
    l_experts = [args.experts]

custom_shape = (args.tokens is not None
                or args.dim is not None
                or args.experts is not None)

# A custom shape replaces the smoke suite with the (partially pinned) grid.
if custom_shape:
    _run_grid(dtypes_to_run, l_tokens, l_d, l_experts)
else:
    run_smoke_tests(dtypes_to_run)

if args.empty or args.all:
    print("\n" + "=" * 60)
    print("EMPTY EXPERTS EDGE CASE")
    print("=" * 60)
    test_empty_experts()

if args.all:
    print("\n" + "=" * 60)
    print("EXTENDED CORRECTNESS GRID")
    print("=" * 60)
    # Shapes already listed in SMOKE_CASES are not repeated here.
    _run_grid(dtypes_to_run, l_tokens, l_d, l_experts, skip=SMOKE_CASES)

print("\n✅ All moe_swiglu_dynamic_quant tests passed.")