test_qk_norm.py 8.18 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from transformer_engine.pytorch import MultiheadAttention

import pytest
import torch


@pytest.mark.parametrize("use_qk_norm", [False, True])
@pytest.mark.parametrize("attention_type", ["self", "cross"])
@pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5])
def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None:
    """Check module layout and forward-pass sanity of MultiheadAttention with/without QK norm.

    For every parameter combination the test:
      * builds a bias-free ``MultiheadAttention`` on CUDA,
      * verifies that a ``qk_norm`` attribute exists exactly when ``use_qk_norm`` is set
        (and that it is an ``L2Normalization`` instance),
      * runs a no-grad forward pass and checks output shape and absence of NaN/Inf,
      * for self-attention only, repeats the forward pass with a rotary embedding tensor.

    Args:
        use_qk_norm: Forwarded to ``MultiheadAttention(use_qk_norm=...)``.
        attention_type: ``"self"`` or ``"cross"``; cross-attention also feeds an encoder output.
        qk_norm_eps: Forwarded to ``MultiheadAttention(qk_norm_eps=...)``.
    """
    hidden_size, num_attention_heads = 256, 8
    seq_len, batch_size = 128, 2  # fixed sizes for testing
    expected_shape = (seq_len, batch_size, hidden_size)
    tensor_opts = {"device": "cuda", "dtype": torch.float32}

    attn = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        attention_type=attention_type,
        use_qk_norm=use_qk_norm,
        qk_norm_eps=qk_norm_eps,
        bias=False,
        device="cuda",
    ).cuda()

    # --- Module structure -------------------------------------------------
    if not use_qk_norm:
        assert not hasattr(attn, "qk_norm"), "Should not have qk_norm module when use_qk_norm=False"
    else:
        assert hasattr(attn, "qk_norm"), "Should have qk_norm module when use_qk_norm=True"
        assert not hasattr(attn, "q_l2norm"), "Should not have separate q_l2norm module"
        assert not hasattr(attn, "k_l2norm"), "Should not have separate k_l2norm module"

        # Imported lazily: only needed for the type check in this branch.
        from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization

        assert isinstance(
            attn.qk_norm, L2Normalization
        ), "qk_norm should be an L2Normalization module"

    # --- Inputs -----------------------------------------------------------
    hidden_states = torch.randn(seq_len, batch_size, hidden_size, **tensor_opts)

    # Cross-attention additionally consumes an encoder output of the same shape.
    forward_kwargs = {}
    if attention_type == "cross":
        forward_kwargs["encoder_output"] = torch.randn(
            seq_len, batch_size, hidden_size, **tensor_opts
        )

    # --- Plain forward pass -----------------------------------------------
    with torch.no_grad():
        output = attn(hidden_states, **forward_kwargs)

    assert output.shape == expected_shape, f"Output shape mismatch: {output.shape}"
    assert not torch.isnan(output).any(), "Output contains NaN"
    assert not torch.isinf(output).any(), "Output contains Inf"

    # --- Forward pass with RoPE (self-attention only) -----------------------
    if attention_type != "self":
        return

    # Rotary tensor of shape (seq_len, 1, 1, head_dim // 2).
    rotary_dim = (hidden_size // num_attention_heads) // 2
    rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, **tensor_opts)

    with torch.no_grad():
        output_with_rope = attn(hidden_states, rotary_pos_emb=rotary_pos_emb)

    assert output_with_rope.shape == expected_shape, "Output shape with RoPE mismatch"
    assert not torch.isnan(output_with_rope).any(), "RoPE output contains NaN"
    assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf"


def test_qk_norm_output_difference() -> None:
    """Test that QK normalization actually changes the output compared to no normalization.

    Two ``MultiheadAttention`` modules are constructed, each right after reseeding
    the CPU and CUDA generators with the same seed, differing only in ``use_qk_norm``.
    Both are run on the same input, and the outputs are required to differ.

    The global CPU/CUDA RNG states are captured on entry and restored on exit so
    that the fixed seed used here does not leak into tests that run afterwards.
    """
    hidden_size = 256
    num_attention_heads = 8
    seq_len = 128
    batch_size = 2

    # Snapshot the global RNG state; it is restored in the ``finally`` below.
    saved_rng_state = torch.get_rng_state()
    saved_cuda_rng_state = torch.cuda.get_rng_state()

    try:
        # Reset to a known seed for reproducible initialization
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)

        # Create model with QK normalization
        mha_with_norm = MultiheadAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            use_qk_norm=True,
            bias=False,
            device="cuda",
        ).cuda()

        # Reset to same seed so the second module is initialized from the same stream
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)

        # Create the counterpart model without QK normalization
        mha_no_norm = MultiheadAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            use_qk_norm=False,
            bias=False,
            device="cuda",
        ).cuda()

        # Input is drawn while the seeded stream is still active, keeping the
        # test data deterministic from run to run.
        hidden_states = torch.randn(
            seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
        )

        # Compare outputs of the two modules on the same input
        with torch.no_grad():
            output_with_norm = mha_with_norm(hidden_states)
            output_no_norm = mha_no_norm(hidden_states)
    finally:
        # Undo the reseeding regardless of success or failure.
        torch.set_rng_state(saved_rng_state)
        torch.cuda.set_rng_state(saved_cuda_rng_state)

    # Outputs should be different when QK normalization is enabled
    assert not torch.allclose(
        output_with_norm, output_no_norm, atol=1e-6
    ), "QK normalization should change the output, but outputs are identical"


def test_qk_norm_with_fused_qkv() -> None:
    """Test QK normalization works with fused QKV parameters.

    Builds a bias-free ``MultiheadAttention`` with both ``fuse_qkv_params`` and
    ``use_qk_norm`` enabled, runs a no-grad forward pass on random input, and
    checks that the output has the input's shape and contains no NaN/Inf values
    (the same sanity checks applied by the other forward-pass tests in this file).
    """
    hidden_size = 256
    num_attention_heads = 8
    seq_len = 64
    batch_size = 2  # Use a fixed batch size for testing

    mha = MultiheadAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        fuse_qkv_params=True,
        use_qk_norm=True,
        bias=False,
        device="cuda",
    ).cuda()

    # Create input and test forward pass
    hidden_states = torch.randn(
        seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
    )

    with torch.no_grad():
        output = mha(hidden_states)

    assert output.shape == (
        seq_len,
        batch_size,
        hidden_size,
    ), f"Output shape mismatch: {output.shape}"
    # A correctly shaped output can still be numerically broken; check finiteness too.
    assert not torch.isnan(output).any(), "Output contains NaN"
    assert not torch.isinf(output).any(), "Output contains Inf"


def test_qk_norm_transformer_layer_output_difference() -> None:
    """Test that QK normalization actually changes TransformerLayer output compared to no normalization.

    Two ``TransformerLayer`` modules are constructed, each right after reseeding
    the CPU and CUDA generators with the same seed, differing only in ``use_qk_norm``.
    Both are run on the same input; the outputs must differ, have the input's
    shape, and be free of NaN/Inf.

    The global CPU/CUDA RNG states are captured on entry and restored on exit so
    that the fixed seed used here does not leak into tests that run afterwards.
    """
    from transformer_engine.pytorch import TransformerLayer

    hidden_size = 256
    ffn_hidden_size = 1024
    num_attention_heads = 8
    seq_len = 128
    batch_size = 2

    # Snapshot the global RNG state; it is restored in the ``finally`` below.
    saved_rng_state = torch.get_rng_state()
    saved_cuda_rng_state = torch.cuda.get_rng_state()

    try:
        # Reset to a known seed for reproducible initialization
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)

        # Create TransformerLayer with QK normalization
        transformer_with_norm = TransformerLayer(
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_attention_heads=num_attention_heads,
            use_qk_norm=True,
            bias=False,
            device="cuda",
        ).cuda()

        # Reset to same seed so the second layer is initialized from the same stream
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)

        # Create the counterpart TransformerLayer without QK normalization
        transformer_no_norm = TransformerLayer(
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_attention_heads=num_attention_heads,
            use_qk_norm=False,
            bias=False,
            device="cuda",
        ).cuda()

        # Input is drawn while the seeded stream is still active, keeping the
        # test data deterministic from run to run.
        hidden_states = torch.randn(
            seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
        )

        # Compare outputs of the two layers on the same input
        with torch.no_grad():
            output_with_norm = transformer_with_norm(hidden_states)
            output_no_norm = transformer_no_norm(hidden_states)
    finally:
        # Undo the reseeding regardless of success or failure.
        torch.set_rng_state(saved_rng_state)
        torch.cuda.set_rng_state(saved_cuda_rng_state)

    # Outputs should be different when QK normalization is enabled
    assert not torch.allclose(
        output_with_norm, output_no_norm, atol=1e-6
    ), "QK normalization should change the TransformerLayer output, but outputs are identical"

    # Check that outputs have expected shapes and properties
    assert output_with_norm.shape == (
        seq_len,
        batch_size,
        hidden_size,
    ), f"Output shape mismatch: {output_with_norm.shape}"
    assert not torch.isnan(output_with_norm).any(), "Output with QK norm contains NaN"
    assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf"
    assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN"
    assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf"