# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from transformer_engine.pytorch import MultiheadAttention

import pytest
import torch


@pytest.mark.parametrize("qk_norm_type", [None, "L2Normalization", "RMSNorm", "LayerNorm"])
@pytest.mark.parametrize("attention_type", ["self", "cross"])
@pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5])
def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> None:
    """Test QK normalization functionality, module structure, and numerical behavior.

    For every combination of normalization type, attention type and epsilon this
    builds a ``MultiheadAttention`` module on CUDA and checks that:

    * ``q_norm`` / ``k_norm`` are ``None`` when ``qk_norm_type`` is ``None``, and
      otherwise are instances of the expected class (shared instance for
      ``L2Normalization``, separate instances for ``RMSNorm`` / ``LayerNorm``);
    * a forward pass returns a NaN/Inf-free tensor of shape
      ``(seq_len, batch_size, hidden_size)``;
    * for self-attention, a forward pass with ``rotary_pos_emb`` does the same.
    """
    hidden_size = 256
    num_attention_heads = 8
    seq_len = 128

    # Create MultiheadAttention module
    mha = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        attention_type=attention_type,
        qk_norm_type=qk_norm_type,
        qk_norm_eps=qk_norm_eps,
        bias=False,
        device="cuda",
    ).cuda()

    # Check module structure based on qk_norm_type parameter
    if qk_norm_type is not None:
        # These two assertions hold for every normalization type, including any
        # type added to the parametrization later without a dedicated branch below.
        assert mha.q_norm is not None, "Should have q_norm module when qk_norm_type is not None"
        assert mha.k_norm is not None, "Should have k_norm module when qk_norm_type is not None"

        # Check that the modules are of the correct type
        if qk_norm_type == "L2Normalization":
            from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization

            assert isinstance(
                mha.q_norm, L2Normalization
            ), "q_norm should be an L2Normalization module"
            assert isinstance(
                mha.k_norm, L2Normalization
            ), "k_norm should be an L2Normalization module"
            # For L2 normalization, q_norm and k_norm should be the same instance (parameter-free)
            assert (
                mha.q_norm is mha.k_norm
            ), "q_norm and k_norm should be the same instance for L2 normalization"

        elif qk_norm_type == "RMSNorm":
            from transformer_engine.pytorch.module.rmsnorm import RMSNorm

            assert isinstance(mha.q_norm, RMSNorm), "q_norm should be an RMSNorm module"
            assert isinstance(mha.k_norm, RMSNorm), "k_norm should be an RMSNorm module"
            # For RMS normalization, q_norm and k_norm should be separate instances
            assert (
                mha.q_norm is not mha.k_norm
            ), "q_norm and k_norm should be separate instances for RMS normalization"

        elif qk_norm_type == "LayerNorm":
            from transformer_engine.pytorch.module.layernorm import LayerNorm

            assert isinstance(mha.q_norm, LayerNorm), "q_norm should be a LayerNorm module"
            assert isinstance(mha.k_norm, LayerNorm), "k_norm should be a LayerNorm module"
            # For LayerNorm, q_norm and k_norm should be separate instances
            assert (
                mha.q_norm is not mha.k_norm
            ), "q_norm and k_norm should be separate instances for LayerNorm"
    else:
        assert mha.q_norm is None, "Should not have q_norm module when qk_norm_type is None"
        assert mha.k_norm is None, "Should not have k_norm module when qk_norm_type is None"

    # Create input tensors
    batch_size = 2  # Use a fixed batch size for testing
    hidden_states = torch.randn(
        seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
    )

    # Test forward pass; the encoder output is only needed (and only created) for cross-attention
    with torch.no_grad():
        if attention_type == "cross":
            encoder_output = torch.randn(
                seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
            )
            output = mha(hidden_states, encoder_output=encoder_output)
        else:
            output = mha(hidden_states)

    # Check output shape and numerical properties
    assert output.shape == (
        seq_len,
        batch_size,
        hidden_size,
    ), f"Output shape mismatch: {output.shape}"
    assert not torch.isnan(output).any(), "Output contains NaN"
    assert not torch.isinf(output).any(), "Output contains Inf"

    # Test with RoPE (if self-attention)
    if attention_type == "self":
        head_dim = hidden_size // num_attention_heads
        rotary_dim = head_dim // 2
        rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32)

        with torch.no_grad():
            output_with_rope = mha(hidden_states, rotary_pos_emb=rotary_pos_emb)

        assert output_with_rope.shape == (
            seq_len,
            batch_size,
            hidden_size,
        ), "Output shape with RoPE mismatch"
        assert not torch.isnan(output_with_rope).any(), "RoPE output contains NaN"
        assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf"


@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
def test_qk_norm_output_difference(qk_norm_type) -> None:
    """Test that QK normalization actually changes the output compared to no normalization.

    Two ``MultiheadAttention`` modules are built after resetting the RNG to the
    same seed, one with ``qk_norm_type`` set and one with ``qk_norm_type=None``.
    Both are run on the same input; the outputs must be finite and must differ.
    """
    hidden_size = 256
    num_attention_heads = 8
    seq_len = 128
    batch_size = 2

    # Reset to a known seed for reproducible initialization
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Create model with QK normalization
    mha_with_norm = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        qk_norm_type=qk_norm_type,
        bias=False,
        device="cuda",
    ).cuda()

    # Reset to same seed for identical initialization
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Create identical model without QK normalization
    mha_no_norm = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        qk_norm_type=None,
        bias=False,
        device="cuda",
    ).cuda()

    # Create input tensors
    hidden_states = torch.randn(
        seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
    )

    # Compare outputs with identical weights but different QK norm settings
    with torch.no_grad():
        output_with_norm = mha_with_norm(hidden_states)
        output_no_norm = mha_no_norm(hidden_states)

    # Guard against a vacuous pass: torch.allclose() returns False whenever either
    # tensor contains NaN (equal_nan defaults to False), which would satisfy the
    # "outputs differ" assertion below even though the module produced garbage.
    assert not torch.isnan(output_with_norm).any(), "Output with QK norm contains NaN"
    assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf"
    assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN"
    assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf"

    # Outputs should be different when QK normalization is enabled
    assert not torch.allclose(
        output_with_norm, output_no_norm, atol=1e-6
    ), f"QK normalization ({qk_norm_type}) should change the output, but outputs are identical"


@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
def test_qk_norm_with_fused_qkv(qk_norm_type) -> None:
    """Test QK normalization works with fused QKV parameters.

    Builds a ``MultiheadAttention`` module with ``fuse_qkv_params=True`` and the
    given ``qk_norm_type``, runs one forward pass without gradients, and checks
    that the output has shape ``(seq_len, batch_size, hidden_size)`` and
    contains no NaN or Inf values.
    """
    hidden_size = 256
    num_attention_heads = 8
    seq_len = 64

    mha = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        fuse_qkv_params=True,
        qk_norm_type=qk_norm_type,
        bias=False,
        device="cuda",
    ).cuda()

    # Create input and test forward pass
    batch_size = 2  # Use a fixed batch size for testing
    hidden_states = torch.randn(
        seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
    )

    with torch.no_grad():
        output = mha(hidden_states)

    assert output.shape == (
        seq_len,
        batch_size,
        hidden_size,
    ), f"Output shape mismatch: {output.shape}"
    # A correctly shaped tensor can still be numerically broken; check finiteness
    # the same way the other tests in this file do.
    assert not torch.isnan(output).any(), "Output contains NaN"
    assert not torch.isinf(output).any(), "Output contains Inf"


@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
def test_qk_norm_transformer_layer_output_difference(qk_norm_type) -> None:
    """Test that QK normalization actually changes TransformerLayer output compared to no normalization.

    Two ``TransformerLayer`` modules are built after resetting the RNG to the
    same seed, one with ``qk_norm_type`` set and one with ``qk_norm_type=None``.
    Both are run on the same input; the outputs must differ, have shape
    ``(seq_len, batch_size, hidden_size)`` and contain no NaN or Inf values.
    """
    from transformer_engine.pytorch import TransformerLayer

    hidden_size = 256
    ffn_hidden_size = 1024
    num_attention_heads = 8
    seq_len = 128
    batch_size = 2

    # Reset to a known seed for reproducible initialization
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Create TransformerLayer with QK normalization
    transformer_with_norm = TransformerLayer(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
        qk_norm_type=qk_norm_type,
        bias=False,
        device="cuda",
    ).cuda()

    # Reset to same seed for identical initialization
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Create identical TransformerLayer without QK normalization
    transformer_no_norm = TransformerLayer(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
        qk_norm_type=None,
        bias=False,
        device="cuda",
    ).cuda()

    # Create input tensors
    hidden_states = torch.randn(
        seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
    )

    # Compare outputs with identical weights but different QK norm settings
    with torch.no_grad():
        output_with_norm = transformer_with_norm(hidden_states)
        output_no_norm = transformer_no_norm(hidden_states)

    # Outputs should be different when QK normalization is enabled
    assert not torch.allclose(output_with_norm, output_no_norm, atol=1e-6), (
        f"QK normalization ({qk_norm_type}) should change the TransformerLayer output, but outputs"
        " are identical"
    )

    # Check that outputs have expected shapes and properties.  The NaN/Inf checks
    # also catch the case where the difference assertion above passed only because
    # one output was non-finite.
    assert output_with_norm.shape == (
        seq_len,
        batch_size,
        hidden_size,
    ), f"Output shape mismatch: {output_with_norm.shape}"
    assert not torch.isnan(output_with_norm).any(), "Output with QK norm contains NaN"
    assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf"
    assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN"
    assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf"


@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
def test_qk_norm_before_after_rope(qk_norm_type) -> None:
    """Test that QK normalization before and after RoPE works without errors.

    Builds one ``MultiheadAttention`` module with ``qk_norm_before_rope=False``
    and one with ``qk_norm_before_rope=True``, runs each with and without
    ``rotary_pos_emb``, and checks that all four outputs have shape
    ``(seq_len, batch_size, hidden_size)`` and contain no NaN or Inf values.
    The two modules are not seeded identically, so their outputs are not
    compared numerically.
    """
    hidden_size = 256
    num_attention_heads = 8
    seq_len = 64
    batch_size = 2

    # Create model with QK norm after RoPE (default)
    mha_after = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        qk_norm_type=qk_norm_type,
        qk_norm_before_rope=False,
        bias=False,
        device="cuda",
    ).cuda()

    # Create model with QK norm before RoPE
    mha_before = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        qk_norm_type=qk_norm_type,
        qk_norm_before_rope=True,
        bias=False,
        device="cuda",
    ).cuda()

    hidden_states = torch.randn(
        seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
    )

    # Create RoPE embeddings
    head_dim = hidden_size // num_attention_heads
    rotary_dim = head_dim // 2
    rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32)

    with torch.no_grad():
        output_after_rope = mha_after(hidden_states, rotary_pos_emb=rotary_pos_emb)
        output_before_rope = mha_before(hidden_states, rotary_pos_emb=rotary_pos_emb)

        output_after_no_rope = mha_after(hidden_states)
        output_before_no_rope = mha_before(hidden_states)

    # Check output shapes and properties
    expected_shape = (seq_len, batch_size, hidden_size)
    for output in [
        output_after_rope,
        output_before_rope,
        output_after_no_rope,
        output_before_no_rope,
    ]:
        assert output.shape == expected_shape, f"Output shape mismatch: {output.shape}"
        assert not torch.isnan(output).any(), "Output contains NaN"
        assert not torch.isinf(output).any(), "Output contains Inf"

    assert output_after_rope.shape == output_before_rope.shape, "Outputs should have same shape"
    # Check the stored flag by truthiness instead of comparing to True/False with "==".
    assert not mha_after.qk_norm_before_rope, "mha_after should have qk_norm_before_rope=False"
    assert mha_before.qk_norm_before_rope, "mha_before should have qk_norm_before_rope=True"


def test_different_qk_norm_types_produce_different_outputs() -> None:
    """Check that L2Normalization and RMSNorm QK norms give distinct attention outputs.

    Both ``MultiheadAttention`` modules are constructed right after resetting the
    RNG to the same seed, then run on one shared input.  The outputs must differ,
    share a shape, and contain no NaN or Inf values.
    """
    seed = 42
    seq_len, batch_size = 128, 2
    mha_dims = {"hidden_size": 256, "num_attention_heads": 8}

    def seeded_mha(norm_name):
        # Re-seed immediately before each construction so both modules are
        # built from the same RNG state.
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        return MultiheadAttention(
            qk_norm_type=norm_name,
            bias=False,
            device="cuda",
            **mha_dims,
        ).cuda()

    mha_l2 = seeded_mha("L2Normalization")
    mha_rms = seeded_mha("RMSNorm")

    # Single shared input for both modules
    hidden_states = torch.randn(
        seq_len, batch_size, mha_dims["hidden_size"], device="cuda", dtype=torch.float32
    )

    with torch.no_grad():
        output_l2 = mha_l2(hidden_states)
        output_rms = mha_rms(hidden_states)

    # Different normalization types must not yield (numerically) the same result
    outputs_match = torch.allclose(output_l2, output_rms, atol=1e-6)
    assert (
        not outputs_match
    ), "L2 and RMS normalization should produce different outputs, but outputs are identical"

    # Shape agreement, then finiteness of each output in turn
    assert output_l2.shape == output_rms.shape, "L2 and RMS outputs should have same shape"
    finite_checks = (
        (output_l2, "L2 output contains NaN", "L2 output contains Inf"),
        (output_rms, "RMS output contains NaN", "RMS output contains Inf"),
    )
    for result, nan_message, inf_message in finite_checks:
        assert not torch.isnan(result).any(), nan_message
        assert not torch.isinf(result).any(), inf_message