# test_dropout_layer_norm.py
import math

import torch
import torch.nn.functional as F
import pytest

7
from einops import rearrange, repeat
8
9

from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm
10
from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset
Tri Dao's avatar
Tri Dao committed
11
12
from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm
from flash_attn.ops.rms_norm import dropout_add_rms_norm_subset
13
14
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
Tri Dao's avatar
Tri Dao committed
15
16
17

# Apex is an optional dependency that supplies the RMSNorm baseline.  Only a
# failed import should trigger the fallback; a bare ``except:`` would also
# swallow KeyboardInterrupt / SystemExit.
try:
    from apex.normalization import FusedRMSNorm
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
except ImportError:
    # The RMSNorm test variants check for ``FusedRMSNorm is None`` and skip.
    FusedRMSNorm, fused_rms_norm_affine = None, None
21
22
23
24


# True when the CUDA device has compute capability major version >= 8; used
# below to decide whether the bf16 parametrizations are added.  Guarded with
# is_available() so that importing this module on a machine without CUDA does
# not raise during test collection.
is_sm8x = torch.cuda.is_available() and torch.cuda.get_device_capability('cuda')[0] >= 8

Tri Dao's avatar
Tri Dao committed
25
@pytest.mark.parametrize('is_rms_norm', [False, True])
@pytest.mark.parametrize('has_colscale', [True, False])
# @pytest.mark.parametrize('has_colscale', [False])
@pytest.mark.parametrize('has_rowscale', [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize('has_residual', [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                     dropout_p, has_residual, has_rowscale, has_colscale, is_rms_norm):
    """Check the fused dropout-add-(Layer|RMS)Norm op, forward and backward, in training mode.

    Three computations are compared:

    * the fused op under test (``x0``, ``res``, ``colscale``, ``model``),
    * a baseline module run in the parametrized dtypes (``*_pt``),
    * the same baseline run in fp32 (``*_ref``).

    Every fused output/gradient must be within a small multiple of the
    ``*_pt``-vs-``*_ref`` error, plus an absolute slack, of the fp32 reference.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip('fp16 weights with bf16 inputs are not supported')
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip("Apex's FusedRMSNorm is needed as the RMSNorm baseline")
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = 'cuda'
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512

    # Identical input values for the three computations; the reference copy is fp32.
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_rowscale:
        # Per-(batch, position) scale: Bernoulli keep mask divided by the survival rate.
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
        x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref

    # Baseline, fp32 reference and fused module all share the same random affine parameters.
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)

    # The flag is only set when no residual tensor is passed.
    # NOTE(review): presumably the residual dtype otherwise follows `res` - confirm.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
                                     model.eps, rowscale=rowscale, layerscale=colscale,
                                     residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
    assert out.dtype == input_dtype
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')

    # Rebuild the pre-norm residual with the dropout mask returned by the fused op,
    # so all three computations drop exactly the same elements.
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    # Backward: feed the same upstream gradient to all three graphs.
    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150


@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    """Check the forward pass of ``DropoutAddLayerNorm`` in eval mode.

    The module is built with a non-zero dropout probability but switched to
    ``eval()``; the expected pre-norm residual is the plain sum ``x0 + res``.
    The fused output must be within 4x the error of a ``torch.nn.LayerNorm``
    run in the parametrized dtypes (``*_pt``), plus 1e-4, of an fp32
    reference (``*_ref``).
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip('fp16 weights with bf16 inputs are not supported')
    device = 'cuda'
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512

    # Identical input values for the three computations; the reference copies are fp32.
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    res = res_pt.detach().clone().requires_grad_()
    res_ref = res_pt.detach().clone().float().requires_grad_()

    # Baseline, fused module and fp32 reference share the same random affine parameters.
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()

    out = model(x0, res)
    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + res_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4


Tri Dao's avatar
Tri Dao committed
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
@pytest.mark.parametrize('is_rms_norm', [False, True])
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_rowscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                             dropout_p, has_residual, has_rowscale, has_colscale,
                                             is_rms_norm):
    """Check the fused op with ``prenorm=True``, forward and backward, in training mode.

    With ``prenorm=True`` the fused function returns the pre-norm residual in
    addition to the normalized output; both are checked, and both take part in
    the loss that is back-propagated.  As in the other tests, the fused results
    are compared against a baseline in the parametrized dtypes (``*_pt``) and
    an fp32 reference (``*_ref``): the fused error w.r.t. the reference must be
    within a small multiple of the baseline's error plus an absolute slack.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip('fp16 weights with bf16 inputs are not supported')
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip("Apex's FusedRMSNorm is needed as the RMSNorm baseline")
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = 'cuda'
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512

    # Identical input values for the three computations; the reference copy is fp32.
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_rowscale:
        # Per-(batch, position) scale: Bernoulli keep mask divided by the survival rate.
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
        x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref

    # Baseline, fp32 reference and fused module all share the same random affine parameters.
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(hidden_size, prenorm=True, p=dropout_p, device=device,
                               dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)

    # The flag is only set when no residual tensor is passed.
    # NOTE(review): presumably the residual dtype otherwise follows `res` - confirm.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
                                               model.eps, rowscale=rowscale,
                                               layerscale=colscale, prenorm=True,
                                               residual_in_fp32=residual_in_fp32,
                                               return_dropout_mask=True)
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')

    # Rebuild the pre-norm residual with the dropout mask returned by the fused op,
    # so all three computations drop exactly the same elements.
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

    # Backward: the loss multiplies the output by sigmoid(residual), so gradients
    # reach the inputs through both returned tensors.
    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt)).backward(g)
    (out * F.sigmoid(residual)).backward(g)
    (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306


@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    """Check the forward pass of ``DropoutAddLayerNorm(prenorm=True)`` in eval mode.

    The module returns ``(out, residual)``.  It is built with a non-zero
    dropout probability but switched to ``eval()``; the expected residual is
    the plain sum ``x0 + res``.  Both returned tensors must be within 4x the
    error of a baseline in the parametrized dtypes (``*_pt``), plus 1e-4, of
    an fp32 reference (``*_ref``).
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip('fp16 weights with bf16 inputs are not supported')
    device = 'cuda'
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512

    # Identical input values for the three computations; the reference copies are fp32.
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    res = res_pt.detach().clone().requires_grad_()
    res_ref = res_pt.detach().clone().float().requires_grad_()

    # Baseline, fused module and fp32 reference share the same random affine parameters.
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()

    out, residual = model(x0, res)
    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + res_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387


@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_training(
        hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p,
        has_residual, has_colscale):
    """Check ``dropout_add_layer_norm_subset`` forward and backward in training mode.

    Two independent per-sample keep masks are drawn: one selects which batch
    entries of ``x0`` are fed to the fused op, the other selects which batch
    entries of the output are produced.  The baseline (``*_pt``) and the fp32
    reference (``*_ref``) emulate this by zeroing dropped inputs, scaling by
    ``drop_path_scale`` and indexing the output with the output mask.  The
    fused error w.r.t. the reference must be within a small multiple of the
    baseline's error plus an absolute slack.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip('fp16 weights with bf16 inputs are not supported')
    device = 'cuda'
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        """Return ``(mask_batch, numrows, subset)`` for one random per-sample keep mask.

        ``mask_batch`` is a boolean keep mask over the batch, ``numrows`` is the
        number of kept rows (kept samples * seqlen) and ``subset`` is a
        (batch, seqlen) int32 tensor holding the running 1-based count of kept
        rows, with 0 at dropped rows.
        """
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0,
                              dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
        return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)

    # The input row count is not needed by the fused call below.
    x0_mask_batch, _, x0_subset = generate_droppath_masks(batch_size, seqlen,
                                                          drop_path_rate, device)
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen,
                                                                      drop_path_rate, device)

    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    # The fused op only receives the kept batch entries of x0.
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None

    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref

    # Baseline, fp32 reference and fused module all share the same random affine parameters.
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, prenorm=False, p=dropout_p, device=device,
                                dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)

    # The flag is only set when no residual tensor is passed.
    # NOTE(review): presumably the residual dtype otherwise follows `res` - confirm.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = dropout_add_layer_norm_subset(
        x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale,
        x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
        out_numrows=out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True)
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')

    # Emulate the input subset in the baselines: zero the dropped samples and rescale.
    x0_scaled_pt = x0_scaled_pt.masked_fill(
        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
    ) * drop_path_scale
    x0_scaled_ref = x0_scaled_ref.masked_fill(
        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
    ) * drop_path_scale
    # Scatter the returned dropout mask back to the full batch; dropped samples stay 0.
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    # Only the batch entries selected by the output mask are compared.
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    # Backward: feed the same upstream gradient to all three graphs.
    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    # x0 only holds the kept samples, so index the full-batch gradients with the same mask.
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4


@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_prenorm_training(
        hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p,
        has_residual, has_colscale):
    """Check ``dropout_add_layer_norm_subset`` with ``prenorm=True`` against unfused references.

    Two independent per-batch keep masks are drawn: one selects the batch
    entries of ``x0`` handed to the fused function (``x0_subset``), the other
    selects the batch entries whose output is compared (``out_subset``).

    Naming convention used below:
      * no suffix -- tensors going through the fused function under test;
      * ``_pt``   -- unfused PyTorch computation in the test dtypes;
      * ``_ref``  -- unfused PyTorch computation in float32.

    Every check requires the fused-vs-``_ref`` error to stay within a small
    multiple of the ``_pt``-vs-``_ref`` error plus an absolute slack, for the
    forward output, the returned residual, and all gradients.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = 'cuda'
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        """Return (per-batch keep mask, number of kept rows, (batch, seqlen) subset index)."""
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
        # Running count over the flattened keep mask; positions of dropped rows are set to 0.
        subset = torch.cumsum(mask_batch_seqlen, dim=0,
                              dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
        return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)

    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(batch_size, seqlen,
                                                                   drop_path_rate, device)
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen,
                                                                      drop_path_rate, device)

    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    # The fused function only receives the kept batch entries of x0.
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None

    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref

    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                dtype=weight_dtype)
    # All three modules start from the same (randomly initialised) affine parameters.
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)

    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = dropout_add_layer_norm_subset(
        x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale,
        x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
        out_numrows=out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True)
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')

    # Emulate the subset selection in the references: zero the dropped batch
    # entries and rescale the kept ones by the drop-path scale.
    dropped_rows = repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size)
    x0_scaled_pt = x0_scaled_pt.masked_fill(dropped_rows, 0) * drop_path_scale
    x0_scaled_ref = x0_scaled_ref.masked_fill(dropped_rows, 0) * drop_path_scale
    # Scatter the dropout mask returned for the kept rows back to the full batch.
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    # The backward target depends on both the normalized output and the
    # residual, so gradients flow through both values returned in prenorm mode.
    (out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(g)
    (out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g)
    (out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype)) + residual_ref.mean(0, keepdim=True)).backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865


@pytest.mark.parametrize('is_rms_norm', [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize('tied_norm', [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize('has_residual', [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('has_x1', [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_training(
    hidden_size, input_dtype, residual_dtype, weight_dtype,
    dropout_p, has_x1, has_residual, tied_norm, is_rms_norm
):
    """Check the fused parallel-residual dropout + add + norm against unfused references.

    The fused function (layer-norm or RMS-norm variant, chosen by
    ``is_rms_norm``) is compared with two unfused computations that reuse the
    dropout masks it returns:
      * ``_pt``  -- computed in the test dtypes;
      * ``_ref`` -- computed in float32.

    Every check requires the fused-vs-``_ref`` error to stay within a small
    multiple of the ``_pt``-vs-``_ref`` error plus an absolute slack, for the
    forward outputs and for the gradients of inputs, residual, and parameters.
    With ``tied_norm`` no second set of norm parameters is passed and only
    ``out0`` is checked.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and fused_rms_norm_affine is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    our_layer_norm_func = (dropout_add_layer_norm_parallel_residual if not is_rms_norm
                           else dropout_add_rms_norm_parallel_residual)
    device = 'cuda'
    epsilon = 1e-5
    batch_size = 8
    seqlen = 512

    def max_abs_diff(a, b):
        """Largest element-wise absolute difference between two tensors."""
        return (a - b).abs().max()

    def leaf_copies(t):
        """Return (same-dtype copy, float32 copy) of ``t`` as fresh leaves, or (None, None)."""
        if t is None:
            return None, None
        return (t.detach().clone().requires_grad_(),
                t.detach().clone().float().requires_grad_())

    def unfused_norm(inp, weight, bias):
        """Unfused normalization: Apex RMS norm (bias unused) or ``F.layer_norm``."""
        if is_rms_norm:
            return fused_rms_norm_affine(inp, weight, (hidden_size,), eps=epsilon)
        return F.layer_norm(inp, (hidden_size,), weight, bias, eps=epsilon)

    # set seed
    torch.random.manual_seed(0)
    # NOTE: the order of the random draws below determines the test data; keep it stable.
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0, x0_ref = leaf_copies(x0_pt)
    if has_x1:
        x1_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                            requires_grad=True)
        x1, x1_ref = leaf_copies(x1_pt)
    else:
        x1 = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res, res_ref = leaf_copies(res_pt)
    else:
        res = None
    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    bias0 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
             if not is_rms_norm else None)
    weight0_pt, weight0_ref = leaf_copies(weight0)
    bias0_pt, bias0_ref = leaf_copies(bias0)
    if not tied_norm:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        bias1 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
                 if not is_rms_norm else None)
        weight1_pt, weight1_ref = leaf_copies(weight1)
        bias1_pt, bias1_ref = leaf_copies(bias1)
    else:
        weight1, bias1 = None, None
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32

    out0, out1, dmask0, dmask1 = our_layer_norm_func(
        x0, x1, res, weight0, bias0, weight1, bias1, dropout_p,
        epsilon, residual_in_fp32=residual_in_fp32, return_dropout_mask=True
    )
    assert out0.dtype == input_dtype
    if not tied_norm:
        assert out1.dtype == input_dtype
    print(f'Actual dropout fraction: {1 - dmask0.float().mean().item()}')

    # Rebuild the pre-norm residual with the kernel's own dropout masks:
    # dropout(x0) [+ dropout(x1)] [+ res], accumulated in float32.
    residual_pt = (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
    residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
    if has_x1:
        residual_pt = residual_pt + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
        residual_ref = residual_ref + (x1_ref * dmask1.float()) / (1 - dropout_p)
    if has_residual:
        residual_pt = residual_pt + res_pt.float()
        residual_ref = residual_ref + res_ref
    residual_pt = residual_pt.to(dtype=residual_dtype)

    # The cast to weight_dtype is deliberately repeated per norm call so each
    # output gets its own cast node in the autograd graph.
    out0_pt = unfused_norm(residual_pt.to(dtype=weight_dtype), weight0_pt,
                           bias0_pt).to(dtype=input_dtype)
    out0_ref = unfused_norm(residual_ref, weight0_ref, bias0_ref)
    if not tied_norm:
        out1_pt = unfused_norm(residual_pt.to(dtype=weight_dtype), weight1_pt,
                               bias1_pt).to(dtype=input_dtype)
        out1_ref = unfused_norm(residual_ref, weight1_ref, bias1_ref)

    assert max_abs_diff(out0, out0_ref) <= 4 * max_abs_diff(out0_pt, out0_ref) + 1e-4
    if not tied_norm:
        assert max_abs_diff(out1, out1_ref) <= 4 * max_abs_diff(out1_pt, out1_ref) + 1e-4

    g0 = torch.randn_like(out0) / batch_size
    if tied_norm:
        for first_out in (out0, out0_pt, out0_ref):
            first_out.backward(g0)
    else:
        g1 = torch.randn_like(out1) / batch_size
        for first_out, second_out in ((out0, out1), (out0_pt, out1_pt), (out0_ref, out1_ref)):
            (first_out * g0 + second_out * g1).sum().backward()

    assert max_abs_diff(x0.grad, x0_ref.grad) <= 4 * max_abs_diff(x0_pt.grad, x0_ref.grad) + 1e-4
    if has_x1:
        assert max_abs_diff(x1.grad, x1_ref.grad) <= 4 * max_abs_diff(x1_pt.grad, x1_ref.grad) + 1e-4
    if has_residual:
        assert max_abs_diff(res.grad, res_ref.grad) <= 4 * max_abs_diff(res_pt.grad, res_ref.grad) + 1e-4
    assert max_abs_diff(weight0.grad, weight0_ref.grad) <= 3 * max_abs_diff(weight0_pt.grad, weight0_ref.grad) + 3e-5
    if not is_rms_norm:
        assert max_abs_diff(bias0.grad, bias0_ref.grad) <= 2 * max_abs_diff(bias0_pt.grad, bias0_ref.grad) + 3e-5
    if not tied_norm:
        assert max_abs_diff(weight1.grad, weight1_ref.grad) <= 3 * max_abs_diff(weight1_pt.grad, weight1_ref.grad) + 3e-5
        if not is_rms_norm:
            assert max_abs_diff(bias1.grad, bias1_ref.grad) <= 2 * max_abs_diff(bias1_pt.grad, bias1_ref.grad) + 3e-5


@pytest.mark.parametrize('is_rms_norm', [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize('tied_norm', [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize('has_residual', [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('has_x1', [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_prenorm_training(
    hidden_size, input_dtype, residual_dtype, weight_dtype,
    dropout_p, has_x1, has_residual, tied_norm, is_rms_norm
):
    """Check the fused parallel-residual norm with ``prenorm=True`` against unfused references.

    With ``prenorm=True`` the fused function additionally returns the pre-norm
    residual, which is checked alongside the normalized outputs. The fused
    results are compared with two unfused computations that reuse the dropout
    masks it returns:
      * ``_pt``  -- computed in the test dtypes;
      * ``_ref`` -- computed in float32.

    Every check requires the fused-vs-``_ref`` error to stay within a small
    multiple of the ``_pt``-vs-``_ref`` error plus an absolute slack. The
    backward target multiplies ``out0`` by ``sigmoid(residual)`` so gradients
    also flow through the returned residual.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and fused_rms_norm_affine is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    our_layer_norm_func = (dropout_add_layer_norm_parallel_residual if not is_rms_norm
                           else dropout_add_rms_norm_parallel_residual)
    device = 'cuda'
    epsilon = 1e-5
    batch_size = 8
    seqlen = 512

    def max_abs_diff(a, b):
        """Largest element-wise absolute difference between two tensors."""
        return (a - b).abs().max()

    def leaf_copies(t):
        """Return (same-dtype copy, float32 copy) of ``t`` as fresh leaves, or (None, None)."""
        if t is None:
            return None, None
        return (t.detach().clone().requires_grad_(),
                t.detach().clone().float().requires_grad_())

    def unfused_norm(inp, weight, bias):
        """Unfused normalization: Apex RMS norm (bias unused) or ``F.layer_norm``."""
        if is_rms_norm:
            return fused_rms_norm_affine(inp, weight, (hidden_size,), eps=epsilon)
        return F.layer_norm(inp, (hidden_size,), weight, bias, eps=epsilon)

    # set seed
    torch.random.manual_seed(0)
    # NOTE: the order of the random draws below determines the test data; keep it stable.
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0, x0_ref = leaf_copies(x0_pt)
    if has_x1:
        x1_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                            requires_grad=True)
        x1, x1_ref = leaf_copies(x1_pt)
    else:
        x1 = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res, res_ref = leaf_copies(res_pt)
    else:
        res = None
    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    bias0 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
             if not is_rms_norm else None)
    weight0_pt, weight0_ref = leaf_copies(weight0)
    bias0_pt, bias0_ref = leaf_copies(bias0)
    if not tied_norm:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        bias1 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
                 if not is_rms_norm else None)
        weight1_pt, weight1_ref = leaf_copies(weight1)
        bias1_pt, bias1_ref = leaf_copies(bias1)
    else:
        weight1, bias1 = None, None
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32

    out0, out1, residual, dmask0, dmask1 = our_layer_norm_func(
        x0, x1, res, weight0, bias0, weight1, bias1, dropout_p,
        epsilon, prenorm=True, residual_in_fp32=residual_in_fp32, return_dropout_mask=True
    )
    assert out0.dtype == input_dtype
    if not tied_norm:
        assert out1.dtype == input_dtype
    print(f'Actual dropout fraction: {1 - dmask0.float().mean().item()}')

    # Rebuild the pre-norm residual with the kernel's own dropout masks:
    # dropout(x0) [+ dropout(x1)] [+ res], accumulated in float32.
    residual_pt = (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
    residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
    if has_x1:
        residual_pt = residual_pt + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
        residual_ref = residual_ref + (x1_ref * dmask1.float()) / (1 - dropout_p)
    if has_residual:
        residual_pt = residual_pt + res_pt.float()
        residual_ref = residual_ref + res_ref
    residual_pt = residual_pt.to(dtype=residual_dtype)

    # The cast to weight_dtype is deliberately repeated per norm call so each
    # output gets its own cast node in the autograd graph.
    out0_pt = unfused_norm(residual_pt.to(dtype=weight_dtype), weight0_pt,
                           bias0_pt).to(dtype=input_dtype)
    out0_ref = unfused_norm(residual_ref, weight0_ref, bias0_ref)
    if not tied_norm:
        out1_pt = unfused_norm(residual_pt.to(dtype=weight_dtype), weight1_pt,
                               bias1_pt).to(dtype=input_dtype)
        out1_ref = unfused_norm(residual_ref, weight1_ref, bias1_ref)

    assert max_abs_diff(out0, out0_ref) <= 4 * max_abs_diff(out0_pt, out0_ref) + 1e-4
    if not tied_norm:
        assert max_abs_diff(out1, out1_ref) <= 4 * max_abs_diff(out1_pt, out1_ref) + 1e-4
    assert max_abs_diff(residual, residual_ref) <= 4 * max_abs_diff(residual_pt, residual_ref) + 1e-4

    g0 = torch.randn_like(out0) / batch_size
    if tied_norm:
        for first_out, resid in ((out0, residual), (out0_pt, residual_pt),
                                 (out0_ref, residual_ref)):
            (first_out * F.sigmoid(resid)).backward(g0)
    else:
        g1 = torch.randn_like(out1) / batch_size
        for first_out, second_out, resid in ((out0, out1, residual),
                                             (out0_pt, out1_pt, residual_pt),
                                             (out0_ref, out1_ref, residual_ref)):
            (first_out * F.sigmoid(resid) * g0 + second_out * g1).sum().backward()

    assert max_abs_diff(x0.grad, x0_ref.grad) <= 4 * max_abs_diff(x0_pt.grad, x0_ref.grad) + 1e-4
    if has_x1:
        assert max_abs_diff(x1.grad, x1_ref.grad) <= 4 * max_abs_diff(x1_pt.grad, x1_ref.grad) + 1e-4
    if has_residual:
        assert max_abs_diff(res.grad, res_ref.grad) <= 4 * max_abs_diff(res_pt.grad, res_ref.grad) + 1e-4
    assert max_abs_diff(weight0.grad, weight0_ref.grad) <= 3 * max_abs_diff(weight0_pt.grad, weight0_ref.grad) + 3e-5
    if not is_rms_norm:
        assert max_abs_diff(bias0.grad, bias0_ref.grad) <= 2 * max_abs_diff(bias0_pt.grad, bias0_ref.grad) + 3e-5
    if not tied_norm:
        assert max_abs_diff(weight1.grad, weight1_ref.grad) <= 3 * max_abs_diff(weight1_pt.grad, weight1_ref.grad) + 3e-5
        if not is_rms_norm:
            assert max_abs_diff(bias1.grad, bias1_ref.grad) <= 2 * max_abs_diff(bias1_pt.grad, bias1_ref.grad) + 3e-5
Tri Dao's avatar
Tri Dao committed
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891


def test_dropout_layer_norm_randomness():
    """Dropout masks must differ between consecutive calls and repeat after reseeding."""
    device = 'cuda'
    dtype = torch.float32
    hidden_size, batch_size, seqlen = 256, 8, 512
    dropout_p = 0.1
    # set seed
    torch.random.manual_seed(0)
    x0 = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=dtype, requires_grad=True)
    res = torch.randn_like(x0, dtype=dtype, requires_grad=True)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=dtype)

    def draw_mask():
        """Run the fused op once on the fixed inputs and return only its dropout mask."""
        _, mask = dropout_add_layer_norm(x0, res, model.weight, model.bias, model.p,
                                         model.eps, return_dropout_mask=True)
        return mask

    torch.random.manual_seed(42)
    first_mask = draw_mask()
    # Subsequent call should have a different dropout mask
    second_mask = draw_mask()
    # Resetting the seed, should get the same dropout mask
    torch.random.manual_seed(42)
    replayed_mask = draw_mask()

    assert not torch.equal(first_mask, second_mask)
    assert torch.equal(first_mask, replayed_mask)