# SPDX-License-Identifier: MIT
"""Correctness tests and micro-benchmarks for ``aiter.ops.rmsnorm.head_rms_norm``.

Intended to be run as a script; see ``--help`` for the available test selections.
"""
import argparse
import itertools

import torch
import torch.nn.functional as F

import aiter
from aiter import dtypes
from aiter.ops.rmsnorm import head_rms_norm
from aiter.test_common import checkAllclose, perftest


# ---------------------------------------------------------------------------
# Pure-PyTorch reference: head-wise RMS norm
# ---------------------------------------------------------------------------
def head_rms_norm_ref(input, weight, eps, head_dim):
    """
    Reference implementation using only torch ops.

    Each contiguous group of ``head_dim`` features along the last axis is
    treated as one head. It is normalised by its own RMS and then scaled
    element-wise by ``weight``.

    input:  [..., H*D] where H = num_heads, D = head_dim
            (the tests in this file use 2-D [M, H*D] inputs)
    weight: [H*D]
    eps:    added to the per-head mean square before the square root

    Returns a tensor with the same shape and dtype as ``input``. The
    arithmetic is done in float32.

    Raises ValueError if ``head_dim`` is not positive, if it does not divide
    the last dim of ``input``, or if ``weight`` does not have exactly that
    many elements.
    """
    hidden = input.shape[-1]
    if head_dim <= 0 or hidden % head_dim != 0:
        raise ValueError(
            f"head_dim={head_dim} must be positive and divide the last dim ({hidden})"
        )
    if weight.numel() != hidden:
        raise ValueError(f"weight has {weight.numel()} elements, expected {hidden}")

    H = hidden // head_dim
    # reshape (not view) so non-contiguous inputs and any leading dims work
    x = input.reshape(*input.shape[:-1], H, head_dim).float()
    w = weight.reshape(H, head_dim).float()
    # per-head RMS
    rms = torch.sqrt((x * x).mean(dim=-1, keepdim=True) + eps)  # [..., H, 1]
    y = (x / rms) * w  # [..., H, D]
    return y.reshape(input.shape).to(input.dtype)


# ---------------------------------------------------------------------------
# Perf wrappers
# ---------------------------------------------------------------------------
@perftest()
def run_ref(input, weight, eps, head_dim):
    """Timed wrapper around the torch reference (timing done by ``perftest``)."""
    return head_rms_norm_ref(input, weight, eps, head_dim)


@perftest()
def run_aiter(input, weight, eps, head_dim):
    """Timed wrapper around aiter's ``head_rms_norm`` (timing done by ``perftest``)."""
    return head_rms_norm(input, weight, eps, head_dim)


# ---------------------------------------------------------------------------
# Test & benchmark
# ---------------------------------------------------------------------------
def test_head_rms_norm(dtype, m, n, head_dim):
    """Correctness test: compare aiter head_rms_norm against torch reference.

    Both paths are timed through the perf wrappers. The global cosine
    similarity and the per-head average cosine similarity are printed as
    extra diagnostics. The element-wise comparison is done by
    ``checkAllclose``.
    """
    dim = (m, n)
    x = torch.randn(dim, dtype=dtype, device="cuda")
    weight = torch.randn(n, dtype=dtype, device="cuda")
    eps = 1e-5

    ref_out, avg_ref = run_ref(x, weight, eps, head_dim)
    aiter_out, avg_aiter = run_aiter(x, weight, eps, head_dim)

    ref_f = ref_out.float()
    aiter_f = aiter_out.float()

    # cosine similarity over the whole output
    cos_sim = F.cosine_similarity(ref_f.flatten(), aiter_f.flatten(), dim=0).item()

    # Per-head cosine similarity (finer-grained check). Regroup to
    # [H, m*head_dim] so that row h holds head h's values for every token,
    # then compare all heads in a single batched call instead of H slices.
    H = n // head_dim
    ref_heads = ref_f.reshape(m, H, head_dim).transpose(0, 1).reshape(H, -1)
    aiter_heads = aiter_f.reshape(m, H, head_dim).transpose(0, 1).reshape(H, -1)
    head_cos_avg = F.cosine_similarity(ref_heads, aiter_heads, dim=1).mean().item()

    # Speedup of aiter over the reference: > 0 means aiter is faster.
    uplift = avg_ref / avg_aiter - 1
    msg = (
        f"[perf] dim: {str(dim):<20}, head_dim: {head_dim}, "
        f"dtype: {dtype}, ref avg: {avg_ref:<8.2f} us, "
        f"aiter avg: {avg_aiter:<8.2f} us, uplift: {uplift:<5.1%}"
    )
    print(
        f"[cos] dim: {str(dim):<20}, head_dim: {head_dim}, "
        f"dtype: {dtype}, cos(ref,aiter): {cos_sim:.8f}, "
        f"head_cos(avg): {head_cos_avg:.8f}"
    )
    checkAllclose(ref_out, aiter_out, atol=0.02, rtol=0.002, msg=msg)


def test_head_rms_norm_vs_global_rmsnorm(dtype, m, n, head_dim=None):
    """
    Verify that when num_heads == 1 (head_dim == n), head_rms_norm is
    equivalent to standard rms_norm.

    ``head_dim`` defaults to ``n``. Any other value raises ValueError,
    because the equivalence only holds for a single head.
    """
    if head_dim is None:
        head_dim = n
    if head_dim != n:
        raise ValueError(
            "equivalence with rms_norm requires head_dim == n, "
            f"got head_dim={head_dim}, n={n}"
        )

    dim = (m, n)
    x = torch.randn(dim, dtype=dtype, device="cuda")
    weight = torch.randn(n, dtype=dtype, device="cuda")
    eps = 1e-5

    # head_rms_norm with head_dim == n (single head)
    head_out = head_rms_norm(x, weight, eps, head_dim)
    # standard rms_norm
    rms_out = aiter.rms_norm(x, weight, eps)

    cos_sim = F.cosine_similarity(
        head_out.float().flatten(), rms_out.float().flatten(), dim=0
    ).item()
    print(
        f"[equiv-check] dim: {str(dim):<20}, dtype: {dtype}, "
        f"head_rms_norm vs rms_norm cos: {cos_sim:.8f}"
    )
    checkAllclose(
        rms_out,
        head_out,
        atol=0.02,
        rtol=0.002,
        msg="head_rms_norm should match rms_norm when head_dim == hidden_dim",
    )


def test_head_rms_norm_gradient(dtype, m, n, head_dim):
    """Optional: verify autograd (requires backward kernel; not used in inference).

    Runs forward plus ``out.sum().backward()`` and asserts that the input and
    weight gradients exist and are finite. Configurations where ``head_dim``
    is not positive or does not evenly divide ``n`` are skipped.
    """
    if head_dim <= 0 or n % head_dim != 0:
        return  # skip invalid config

    dim = (m, n)
    x = torch.randn(dim, dtype=dtype, device="cuda", requires_grad=True)
    weight = torch.randn(n, dtype=dtype, device="cuda", requires_grad=True)
    eps = 1e-5

    # forward
    out = head_rms_norm(x, weight, eps, head_dim)
    loss = out.sum()
    loss.backward()

    # gradients should be non-None and finite
    assert x.grad is not None, "input.grad is None"
    assert weight.grad is not None, "weight.grad is None"
    assert torch.isfinite(x.grad).all(), "input.grad has NaN/Inf"
    assert torch.isfinite(weight.grad).all(), "weight.grad has NaN/Inf"

    print(
        f"[grad] dim: {str(dim):<20}, head_dim: {head_dim}, dtype: {dtype}, "
        f"input.grad mean: {x.grad.mean().item():.8f}, "
        f"weight.grad mean: {weight.grad.mean().item():.8f}"
    )


# ---------------------------------------------------------------------------
# Edge-case tests
# ---------------------------------------------------------------------------
def _divisible_configs(dtype_list, m_list, n_list, head_dim_list):
    """Yield (dtype, m, n, head_dim) combinations, in nested-loop order, where head_dim divides n."""
    for dtype, m, n, head_dim in itertools.product(
        dtype_list, m_list, n_list, head_dim_list
    ):
        if n % head_dim == 0:
            yield dtype, m, n, head_dim


def test_small_inputs(dtype_list=None):
    """Test with very small inputs to catch boundary bugs.

    dtype_list: dtypes to test. Defaults to fp16 and bf16.
    """
    if dtype_list is None:
        dtype_list = [dtypes.fp16, dtypes.bf16]
    for dtype, m, n, head_dim in _divisible_configs(
        dtype_list, [1, 3], [128, 256], [32, 64, 128]
    ):
        test_head_rms_norm(dtype, m, n, head_dim)


def test_large_inputs(dtype_list=None, m_list=None, n_list=None, head_dim_list=None):
    """Test with larger, more realistic input sizes.

    Each argument overrides one axis of the sweep. When an argument is None,
    its default is used: fp16/bf16, M in {1, 16, 128, 1024},
    N in {4096, 8192}, head_dim in {64, 128}. Combinations where head_dim
    does not divide N are skipped.
    """
    if dtype_list is None:
        dtype_list = [dtypes.fp16, dtypes.bf16]
    if m_list is None:
        m_list = [1, 16, 128, 1024]
    if n_list is None:
        n_list = [4096, 8192]
    if head_dim_list is None:
        head_dim_list = [64, 128]
    for dtype, m, n, head_dim in _divisible_configs(
        dtype_list, m_list, n_list, head_dim_list
    ):
        test_head_rms_norm(dtype, m, n, head_dim)


def _banner(title):
    """Print a section header framed by '=' rules."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
l_dtype = ["fp16", "bf16"]
l_m = [1, 16, 64, 128, 256, 1024]
l_n = [1024, 4096]
l_head_dim = [64, 128]

parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="Test and benchmark head_rms_norm (default: forward-only / inference)",
)
parser.add_argument(
    "-d", "--dtype", type=str, choices=l_dtype, default=None,
    help="Data type, e.g. -d bf16",
)
parser.add_argument("-m", "--m", type=int, default=None, help="M (num_tokens)")
parser.add_argument("-n", "--n", type=int, default=None, help="N (hidden_size)")
parser.add_argument("--head_dim", type=int, default=None, help="Head dimension")
parser.add_argument(
    "--equiv", action="store_true",
    help="Run equivalence test against standard rms_norm",
)
parser.add_argument(
    "--grad", action="store_true",
    help="Run gradient test (off by default; head_rms_norm is inference-only)",
)
parser.add_argument("--small", action="store_true", help="Run small-input edge-case tests")
parser.add_argument("--large", action="store_true", help="Run large-input benchmark tests")
parser.add_argument("--all", action="store_true", help="Run all tests including --grad")
args = parser.parse_args()

# Resolve CLI overrides. Each flag narrows its sweep axis to a single value.
if args.dtype is None:
    l_dtype = [dtypes.d_dtypes[key] for key in l_dtype]
else:
    l_dtype = [dtypes.d_dtypes[args.dtype]]
if args.m is not None:
    l_m = [args.m]
if args.n is not None:
    l_n = [args.n]
if args.head_dim is not None:
    l_head_dim = [args.head_dim]

# With no section flag, every non-opt-in section runs (gradient stays opt-in).
run_all = args.all or not (args.small or args.large or args.equiv or args.grad)

if run_all or args.small:
    _banner("SMALL-INPUT EDGE-CASE TESTS")
    test_small_inputs(l_dtype)

if run_all or args.large:
    _banner("LARGE-INPUT CORRECTNESS & PERF TESTS")
    test_large_inputs(l_dtype, l_m, l_n, l_head_dim)

if run_all or args.equiv:
    _banner("EQUIVALENCE TEST: head_rms_norm vs rms_norm")
    for dtype in l_dtype:
        for m in [1, 16, 128]:
            for n in [256, 1024]:
                test_head_rms_norm_vs_global_rmsnorm(dtype, m, n, head_dim=n)

# Gradient test is opt-in: head_rms_norm has forward CUDA only (no autograd kernel).
# Gradient section: runs only when --grad or --all is passed. It honours the
# -d/--dtype selection (l_dtype) like the other sections, but keeps its own
# fixed shape sweep.
if args.grad or args.all:
    print("\n" + "=" * 60)
    print("GRADIENT TEST (opt-in)")
    print("=" * 60)
    for dtype in l_dtype:
        for m in [4, 16]:
            for n in [1024]:
                for hd in [64, 128]:
                    test_head_rms_norm_gradient(dtype, m, n, hd)

print("\n✅ All head_rms_norm inference tests passed.")