test_head_rms_norm.py 8.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
# SPDX-License-Identifier: MIT
 
import torch
import torch.nn.functional as F
import aiter
from aiter.test_common import checkAllclose, perftest
from aiter.ops.rmsnorm import head_rms_norm
from aiter import dtypes
import argparse


# ---------------------------------------------------------------------------
# Pure-PyTorch reference: head-wise RMS norm
# ---------------------------------------------------------------------------
def head_rms_norm_ref(input, weight, eps, head_dim):
    """
    Reference head-wise RMS norm using only torch ops.

    The last dimension of ``input`` is split into ``H = hidden // head_dim``
    heads and every head is normalized on its own, in float32:
    ``y = x / sqrt(mean(x * x) + eps) * w``.

    Args:
        input:    tensor of shape [..., H*D] (typically [M, H*D]) where
                  H = num_heads and D = head_dim.
        weight:   tensor holding H*D scale values.
        eps:      value added to the per-head mean square before the sqrt.
        head_dim: width D of one head; must be positive and divide the last
                  dimension of ``input``.

    Returns:
        Tensor with the same shape and dtype as ``input``.

    Raises:
        ValueError: if ``head_dim`` is not positive or does not evenly divide
            the last dimension of ``input``.
    """
    hidden = input.shape[-1]
    if head_dim <= 0 or hidden % head_dim != 0:
        raise ValueError(
            f"head_dim ({head_dim}) must be positive and evenly divide "
            f"the hidden size ({hidden})"
        )
    H = hidden // head_dim

    # Split only the last dimension so any number of leading dims is accepted:
    # [..., H*D] -> [..., H, D]
    x = input.reshape(*input.shape[:-1], H, head_dim).float()
    w = weight.reshape(H, head_dim).float()

    # per-head RMS
    rms = torch.sqrt((x * x).mean(dim=-1, keepdim=True) + eps)  # [..., H, 1]
    y = (x / rms) * w  # [..., H, D]; w broadcasts over the leading dims

    return y.reshape(input.shape).to(input.dtype)


# ---------------------------------------------------------------------------
# Perf wrappers
# ---------------------------------------------------------------------------
@perftest()
def run_ref(input, weight, eps, head_dim):
    """
    Perf-wrapped entry point for the pure-PyTorch reference.

    Forwards all arguments unchanged to head_rms_norm_ref. The call site in
    this file unpacks the decorated call as ``(output, average_time)``.
    """
    reference_output = head_rms_norm_ref(input, weight, eps, head_dim)
    return reference_output


@perftest()
def run_aiter(input, weight, eps, head_dim):
    """
    Perf-wrapped entry point for the aiter head_rms_norm op.

    Forwards all arguments unchanged to head_rms_norm. The call site in this
    file unpacks the decorated call as ``(output, average_time)``.
    """
    kernel_output = head_rms_norm(input, weight, eps, head_dim)
    return kernel_output


# ---------------------------------------------------------------------------
# Test & benchmark
# ---------------------------------------------------------------------------
def test_head_rms_norm(dtype, m, n, head_dim):
    """
    Correctness test: compare aiter head_rms_norm against the torch reference.

    Builds a random [m, n] input and an [n] weight on "cuda", runs both
    implementations through their perf wrappers, prints the whole-tensor and
    the average per-head cosine similarity, and hands the element-wise
    comparison (atol=0.02, rtol=0.002) plus the perf summary to checkAllclose.

    Args:
        dtype:    dtype of the generated input and weight tensors.
        m:        number of rows of the input.
        n:        hidden size (number of columns).
        head_dim: width of one head; the sweeps in this file only pass values
                  that divide n.
    """
    dim = (m, n)
    x = torch.randn(dim, dtype=dtype, device="cuda")
    weight = torch.randn(n, dtype=dtype, device="cuda")
    eps = 1e-5

    ref_out, avg_ref = run_ref(x, weight, eps, head_dim)
    aiter_out, avg_aiter = run_aiter(x, weight, eps, head_dim)

    # Compare in float32 regardless of the tested dtype.
    ref_f32 = ref_out.float()
    aiter_f32 = aiter_out.float()

    # cosine similarity over the whole output
    cos_sim = F.cosine_similarity(
        ref_f32.flatten(), aiter_f32.flatten(), dim=0
    ).item()

    # per-head cosine similarity (finer-grained check), computed for all heads
    # in one batched call so only a single value is copied back to the host.
    H = n // head_dim
    used = H * head_dim  # columns that belong to a complete head
    ref_heads = ref_f32[:, :used].reshape(m, H, head_dim).transpose(0, 1).reshape(H, -1)
    aiter_heads = (
        aiter_f32[:, :used].reshape(m, H, head_dim).transpose(0, 1).reshape(H, -1)
    )
    head_cos_avg = (
        F.cosine_similarity(ref_heads, aiter_heads, dim=1).double().mean().item()
    )

    # uplift = speed-up of aiter over the reference: positive when aiter is faster.
    msg = (
        f"[perf] dim: {str(dim):<20}, head_dim: {head_dim}, "
        f"dtype: {dtype}, ref avg: {avg_ref:<8.2f} us, "
        f"aiter avg: {avg_aiter:<8.2f} us, uplift: {avg_ref / avg_aiter - 1:<5.1%}"
    )
    print(
        f"[cos] dim: {str(dim):<20}, head_dim: {head_dim}, "
        f"dtype: {dtype}, cos(ref,aiter): {cos_sim:.8f}, "
        f"head_cos(avg): {head_cos_avg:.8f}"
    )

    checkAllclose(ref_out, aiter_out, atol=0.02, rtol=0.002, msg=msg)


def test_head_rms_norm_vs_global_rmsnorm(dtype, m, n, head_dim):
    """
    Check head_rms_norm against aiter.rms_norm on the same random data.

    The comparison is meant for the single-head case (head_dim == n), where
    head-wise normalization covers the whole row and should therefore agree
    with the standard rms_norm. Prints the cosine similarity of the two
    outputs and passes them to checkAllclose (atol=0.02, rtol=0.002).
    """
    shape = (m, n)
    eps = 1e-5
    x = torch.randn(shape, dtype=dtype, device="cuda")
    gamma = torch.randn(n, dtype=dtype, device="cuda")

    # Same inputs through both ops: head-wise first, then the standard norm.
    single_head_out = head_rms_norm(x, gamma, eps, head_dim)
    global_out = aiter.rms_norm(x, gamma, eps)

    flat_head = single_head_out.float().flatten()
    flat_global = global_out.float().flatten()
    similarity = F.cosine_similarity(flat_head, flat_global, dim=0).item()

    print(
        f"[equiv-check] dim: {str(shape):<20}, dtype: {dtype}, "
        f"head_rms_norm vs rms_norm cos: {similarity:.8f}"
    )
    checkAllclose(
        global_out,
        single_head_out,
        atol=0.02,
        rtol=0.002,
        msg="head_rms_norm should match rms_norm when head_dim == hidden_dim",
    )


def test_head_rms_norm_gradient(dtype, m, n, head_dim):
    """
    Optional: verify autograd (requires backward kernel; not used in inference).

    Runs head_rms_norm on random "cuda" tensors that require grad,
    backpropagates the sum of the output, and asserts that both gradients
    exist and contain only finite values. The gradient means are printed for
    inspection.

    Configurations where head_dim is not positive, is larger than n, or does
    not evenly divide n are skipped: the function returns without running.
    """
    # skip invalid config: the heads must tile the hidden dimension exactly
    if head_dim <= 0 or head_dim > n or n % head_dim != 0:
        return

    dim = (m, n)
    x = torch.randn(dim, dtype=dtype, device="cuda", requires_grad=True)
    weight = torch.randn(n, dtype=dtype, device="cuda", requires_grad=True)
    eps = 1e-5

    # forward + backward through a scalar loss
    out = head_rms_norm(x, weight, eps, head_dim)
    loss = out.sum()
    loss.backward()

    # gradients should be non-None and finite
    assert x.grad is not None, "input.grad is None"
    assert weight.grad is not None, "weight.grad is None"
    assert torch.isfinite(x.grad).all(), "input.grad has NaN/Inf"
    assert torch.isfinite(weight.grad).all(), "weight.grad has NaN/Inf"

    print(
        f"[grad] dim: {str(dim):<20}, head_dim: {head_dim}, dtype: {dtype}, "
        f"input.grad mean: {x.grad.mean().item():.8f}, "
        f"weight.grad mean: {weight.grad.mean().item():.8f}"
    )


# ---------------------------------------------------------------------------
# Edge-case tests
# ---------------------------------------------------------------------------
def test_small_inputs():
    """Run test_head_rms_norm over a fixed sweep of tiny shapes (boundary cases)."""
    dtype_choices = (dtypes.fp16, dtypes.bf16)
    row_counts = (1, 3)
    # Keep only (hidden size, head_dim) pairs where the heads tile the row.
    shape_pairs = [
        (n, head_dim)
        for n in (128, 256)
        for head_dim in (32, 64, 128)
        if n % head_dim == 0
    ]
    for dtype in dtype_choices:
        for m in row_counts:
            for n, head_dim in shape_pairs:
                test_head_rms_norm(dtype, m, n, head_dim)


def test_large_inputs():
    """Run test_head_rms_norm over a fixed sweep of larger, more realistic shapes."""
    dtype_choices = (dtypes.fp16, dtypes.bf16)
    row_counts = (1, 16, 128, 1024)
    # Keep only (hidden size, head_dim) pairs where the heads tile the row.
    shape_pairs = [
        (n, head_dim)
        for n in (4096, 8192)
        for head_dim in (64, 128)
        if n % head_dim == 0
    ]
    for dtype in dtype_choices:
        for m in row_counts:
            for n, head_dim in shape_pairs:
                test_head_rms_norm(dtype, m, n, head_dim)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
# Default sweep values used by the CLI-driven test groups.
l_dtype = ["fp16", "bf16"]
l_m = [1, 16, 64, 128, 256, 1024]
l_n = [1024, 4096]
l_head_dim = [64, 128]

parser = argparse.ArgumentParser(
    description="Test and benchmark head_rms_norm (default: forward-only / inference)",
    formatter_class=argparse.RawTextHelpFormatter,
)

# Sweep options: every one defaults to None, i.e. "not given on the command line".
parser.add_argument(
    "-d",
    "--dtype",
    type=str,
    choices=l_dtype,
    default=None,
    help="Data type, e.g. -d bf16",
)
parser.add_argument("-m", "--m", type=int, default=None, help="M (num_tokens)")
parser.add_argument("-n", "--n", type=int, default=None, help="N (hidden_size)")
parser.add_argument("--head_dim", type=int, default=None, help="Head dimension")

# Boolean switches that select which test groups run.
for _flag, _help in (
    ("--equiv", "Run equivalence test against standard rms_norm"),
    ("--grad", "Run gradient test (off by default; head_rms_norm is inference-only)"),
    ("--small", "Run small-input edge-case tests"),
    ("--large", "Run large-input benchmark tests"),
    ("--all", "Run all tests including --grad"),
):
    parser.add_argument(_flag, action="store_true", help=_help)

args = parser.parse_args()

# Map dtype names to the entries of dtypes.d_dtypes: every default name when
# -d was not given, otherwise only the requested one.
_dtype_names = l_dtype if args.dtype is None else [args.dtype]
l_dtype = [dtypes.d_dtypes[key] for key in _dtype_names]

# A value given on the command line collapses that sweep to a single entry.
l_m = l_m if args.m is None else [args.m]
l_n = l_n if args.n is None else [args.n]
l_head_dim = l_head_dim if args.head_dim is None else [args.head_dim]

# With no group flag at all, run the default (forward-only) groups; --all
# additionally switches on the opt-in gradient group further down.
run_all = args.all or not (args.small or args.large or args.equiv or args.grad)

if run_all or args.small:
    print("\n" + "=" * 60)
    print("SMALL-INPUT EDGE-CASE TESTS")
    print("=" * 60)
    # Fixed sweep; not affected by -d / -m / -n / --head_dim.
    test_small_inputs()

if run_all or args.large:
    print("\n" + "=" * 60)
    print("LARGE-INPUT CORRECTNESS & PERF TESTS")
    print("=" * 60)
    for dtype in l_dtype:
        for m in l_m:
            for n in l_n:
                for hd in l_head_dim:
                    # heads must tile the hidden size exactly
                    if n % hd != 0:
                        continue
                    test_head_rms_norm(dtype, m, n, hd)

if run_all or args.equiv:
    print("\n" + "=" * 60)
    print("EQUIVALENCE TEST: head_rms_norm vs rms_norm")
    print("=" * 60)
    for dtype in l_dtype:
        for m in [1, 16, 128]:
            for n in [256, 1024]:
                # single head: head_dim equals the hidden size
                test_head_rms_norm_vs_global_rmsnorm(dtype, m, n, head_dim=n)

# Gradient test is opt-in: head_rms_norm has forward CUDA only (no autograd kernel).
if args.grad or args.all:
    print("\n" + "=" * 60)
    print("GRADIENT TEST (opt-in)")
    print("=" * 60)
    # Use l_dtype so -d restricts this group the same way it restricts the
    # large-input and equivalence groups (default is still fp16 and bf16).
    for dtype in l_dtype:
        for m in [4, 16]:
            for n in [1024]:
                for hd in [64, 128]:
                    test_head_rms_norm_gradient(dtype, m, n, hd)

# NOTE(review): this line is printed whenever the script gets this far; whether
# a failed comparison stops the script earlier depends on checkAllclose — confirm.
print("\n✅ All head_rms_norm inference tests passed.")