# Copyright (c) 2024, Tri Dao.

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
7
8
9
10
11
12
13

from flash_attn.ops.triton.layernorm import (
    layer_norm_fn,
    layer_norm_ref,
    rms_norm_ref,
    layer_norm_linear_fn,
)
14
15
16
17
18


# True when the current CUDA device reports a compute-capability major version >= 8.
# It is used to decide whether the bfloat16 parametrizations are included.
# The capability query is only made when CUDA is available. Otherwise importing this
# module on a CPU-only host would raise during pytest collection.
is_sm8x = torch.cuda.is_available() and torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("has_rowscale", [False, True])
# @pytest.mark.parametrize("has_rowscale", [True])
@pytest.mark.parametrize("dropout_p", [0.0, 0.27])
# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [True])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize("has_residual", [False])
@pytest.mark.parametrize(
    "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("weight_dtype", [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.float16, torch.float16)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 4096])
# @pytest.mark.parametrize("hidden_size", [256])
def test_layer_norm(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    has_residual,
    is_rms_norm,
    prenorm,
    dropout_p,
    has_rowscale,
):
    """Compare ``layer_norm_fn`` (forward and backward) against two reference runs.

    Two references are used:

    * ``_pt``: ``layer_norm_ref`` / ``rms_norm_ref`` evaluated at the input precision.
    * ``_ref``: the same reference with ``upcast=True``.

    A result ``x`` passes when its error against ``_ref`` is at most twice the
    error of ``_pt`` against ``_ref``, plus a dtype-dependent ``atol``. The
    dropout mask returned by ``layer_norm_fn`` is fed to both references, so
    all three runs drop the same elements.
    """
    device = "cuda"
    # Loosest tolerance for the lowest-precision dtype involved.
    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 5e-2
    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 1e-2
    else:
        atol = 1e-4
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
    allclose = (
        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
        # Sometimes x0_pt.grad is NaN
        <= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
    )
    # Three independent leaf copies of every input: fused kernel, _pt reference, _ref reference.
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0_pt = x0.detach().clone().requires_grad_()
    x0_ref = x0.detach().clone().requires_grad_()
    if has_residual:
        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res_pt = res.detach().clone().requires_grad_()
        res_ref = res.detach().clone().requires_grad_()
    else:
        res, res_pt, res_ref = None, None, None
    weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    # RMSNorm has no bias term.
    if not is_rms_norm:
        bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    else:
        bias = None
    weight_pt = weight.detach().clone().requires_grad_()
    weight_ref = weight.detach().clone().requires_grad_()
    bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
    bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None

    rowscale = torch.randn(batch_size, seqlen, dtype=input_dtype, device=device) if has_rowscale else None

    # Only request an fp32 residual when no residual tensor is passed in.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, *rest = layer_norm_fn(
        x0,
        weight,
        bias,
        residual=res,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        residual_in_fp32=residual_in_fp32,
        is_rms_norm=is_rms_norm,
        return_dropout_mask=True,
    )
    # The mask is the last returned element; reuse it so the references drop the same elements.
    dropout_mask = rest[-1] if dropout_p > 0.0 else None
    out_pt = layer_norm_ref_fn(
        x0_pt,
        weight_pt,
        bias_pt,
        residual=res_pt,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        dropout_mask=dropout_mask,
    )
    out_ref = layer_norm_ref_fn(
        x0_ref,
        weight_ref,
        bias_ref,
        residual=res_ref,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        dropout_mask=dropout_mask,
        upcast=True,
    )
    if prenorm:
        residual = rest[0]
        out_pt, residual_pt = out_pt
        out_ref, residual_ref = out_ref
    assert out.dtype == input_dtype
    if prenorm:
        assert residual.dtype == residual_dtype
        assert allclose(residual, residual_pt, residual_ref)
    assert allclose(out, out_pt, out_ref)
    if dropout_mask is not None:
        dropout_fraction = 1.0 - dropout_mask.float().mean()
        assert abs(dropout_fraction - dropout_p) < 0.01

    g = torch.randn_like(out) / batch_size
    if not prenorm:
        out.backward(g)
        out_pt.backward(g)
        out_ref.backward(g)
    else:
        # Route a gradient through the residual output as well, so its backward path is checked.
        (out * F.sigmoid(residual)).backward(g)
        (out_pt * F.sigmoid(residual_pt)).backward(g)
        (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
    if has_residual:
        assert allclose(res.grad, res_pt.grad, res_ref.grad)
    assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
    if bias is not None:
        assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)


@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [True])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize("has_residual", [False])
@pytest.mark.parametrize("weight_dtype", [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000])
# @pytest.mark.parametrize("hidden_size", [256])
def test_layer_norm_linear(
    hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm
):
    """Compare ``layer_norm_linear_fn`` (norm followed by a linear layer) against references.

    The fused call runs under ``torch.autocast``. It is checked against two references:

    * ``_pt``: ``layer_norm_ref`` / ``rms_norm_ref`` followed by ``F.linear``, both under
      autocast.
    * ``_ref``: the same reference with ``upcast=True``, followed by ``F.linear`` in the
      linear weight's dtype.

    The linear layer maps ``hidden_size`` to ``2 * hidden_size``. The pass criterion is the
    same as in ``test_layer_norm``: error against ``_ref`` must be at most twice the ``_pt``
    error, plus ``atol``.
    """
    device = "cuda"
    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 5e-2
    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 1e-2
    else:
        atol = 1e-4
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    seqlen = 512
    # batch_size = 1
    # seqlen = 1
    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
    allclose = (
        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
        <= 2 * (x_pt - x_ref).abs().max() + atol
    )
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0_pt = x0.detach().clone().requires_grad_()
    x0_ref = x0.detach().clone().requires_grad_()
    if has_residual:
        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res_pt = res.detach().clone().requires_grad_()
        res_ref = res.detach().clone().requires_grad_()
    else:
        res, res_pt, res_ref = None, None, None
    norm_weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    if not is_rms_norm:
        norm_bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    else:
        norm_bias = None
    norm_weight_pt = norm_weight.detach().clone().requires_grad_()
    norm_weight_ref = norm_weight.detach().clone().requires_grad_()
    norm_bias_pt = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
    norm_bias_ref = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
    linear_weight = torch.empty(
        2 * hidden_size, hidden_size, device=device, dtype=weight_dtype, requires_grad=True
    )
    torch.nn.init.xavier_uniform_(linear_weight)
    if not is_rms_norm:
        linear_bias = torch.randn(
            2 * hidden_size, device=device, dtype=weight_dtype, requires_grad=True
        )
    else:
        linear_bias = None
    linear_weight_pt = linear_weight.detach().clone().requires_grad_()
    linear_weight_ref = linear_weight.detach().clone().requires_grad_()
    linear_bias_pt = (
        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
    )
    linear_bias_ref = (
        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
    )

    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    # Keep the raw return values and unpack only in the prenorm case. Without prenorm the
    # result is a single tensor, and `out, *rest = tensor` would iterate over the batch
    # dimension, silently restricting every check below to batch element 0.
    with torch.autocast(device_type="cuda", dtype=input_dtype):
        out = layer_norm_linear_fn(
            x0,
            norm_weight,
            norm_bias,
            linear_weight,
            linear_bias,
            residual=res,
            eps=1e-6,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=is_rms_norm,
        )
    out_pt = layer_norm_ref_fn(
        x0_pt, norm_weight_pt, norm_bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm
    )
    out_ref = layer_norm_ref_fn(
        x0_ref,
        norm_weight_ref,
        norm_bias_ref,
        residual=res_ref,
        eps=1e-6,
        prenorm=prenorm,
        upcast=True,
    )
    if prenorm:
        # NOTE(review): assumes layer_norm_linear_fn returns (out, residual, ...) with prenorm
        # and a bare tensor without it, mirroring the reference; confirm against the library.
        out, residual = out[0], out[1]
        out_pt, residual_pt = out_pt
        out_ref, residual_ref = out_ref
    with torch.autocast(device_type="cuda", dtype=input_dtype):
        out_pt = F.linear(out_pt, linear_weight_pt, linear_bias_pt)
    out_ref = F.linear(out_ref.to(linear_weight_ref.dtype), linear_weight_ref, linear_bias_ref)
    # Guard against accidental broadcasting between mismatched shapes in allclose.
    assert out.shape == out_ref.shape
    assert out.dtype == input_dtype
    if prenorm:
        assert residual.dtype == residual_dtype
        assert allclose(residual, residual_pt, residual_ref)
    assert allclose(out, out_pt, out_ref)

    g = torch.randn_like(out) / batch_size
    out.backward(g)
    out_pt.backward(g)
    out_ref.backward(g)
    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
    if has_residual:
        assert allclose(res.grad, res_pt.grad, res_ref.grad)
    assert allclose(norm_weight.grad, norm_weight_pt.grad, norm_weight_ref.grad)
    if norm_bias is not None:
        assert allclose(norm_bias.grad, norm_bias_pt.grad, norm_bias_ref.grad)
    assert allclose(linear_weight.grad, linear_weight_pt.grad, linear_weight_ref.grad)
    if linear_bias is not None:
        assert allclose(linear_bias.grad, linear_bias_pt.grad, linear_bias_ref.grad)