# test_dropout_layer_norm.py
# Tests for the fused dropout + residual-add + LayerNorm / RMSNorm ops in flash_attn.ops.
import math

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange, repeat

from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm
from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset
# RMSNorm counterparts of the fused LayerNorm ops imported above.
from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm
from flash_attn.ops.rms_norm import dropout_add_rms_norm_subset

# Apex is an optional dependency: it only provides the reference RMSNorm
# implementation. The RMSNorm test cases skip themselves when it is missing.
try:
    from apex.normalization import FusedRMSNorm
except ImportError:
    # Only a missing/incomplete Apex install is tolerated; anything else
    # (including KeyboardInterrupt) should propagate instead of being hidden.
    FusedRMSNorm = None


# True when the CUDA device has compute capability major version >= 8; the
# bfloat16 parametrizations below are only added in that case.
# The availability check keeps importing/collecting this module from raising
# on a machine without a usable CUDA device.
is_sm8x = torch.cuda.is_available() and torch.cuda.get_device_capability('cuda')[0] >= 8

@pytest.mark.parametrize('is_rms_norm', [False, True])
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_rowscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# Full size sweep, matching the sibling training tests in this file.
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                     dropout_p, has_residual, has_rowscale, has_colscale, is_rms_norm):
    """Check forward and backward of the fused dropout + add + norm op in training mode.

    Three variants are computed from the same inputs, weights and dropout mask:

    * ``out``     - the fused op under test,
    * ``out_pt``  - plain PyTorch ops in the test's (possibly low-precision) dtypes,
    * ``out_ref`` - plain PyTorch ops in float32.

    The fused op passes when its error w.r.t. the float32 reference is at most a
    small multiple of the error of the plain low-precision version, plus an
    absolute slack. The same criterion is applied to every gradient.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip('fp16 weights with bf16 inputs are not supported')
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip("Apex's FusedRMSNorm is needed as the RMSNorm reference")
    if is_rms_norm:
        layer_norm_cls = FusedRMSNorm
        our_layer_norm_cls = DropoutAddRMSNorm
        our_layer_norm_func = dropout_add_rms_norm
    else:
        layer_norm_cls = torch.nn.LayerNorm
        our_layer_norm_cls = DropoutAddLayerNorm
        our_layer_norm_func = dropout_add_layer_norm
    device = 'cuda'
    # Fixed seed so the inputs are the same on every run.
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    # *_pt tensors feed the plain PyTorch path, *_ref the float32 reference and
    # the unsuffixed ones the fused op. All three start from identical values.
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        # One scale per hidden dimension.
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_rowscale:
        # One scale per (batch, position) row: a Bernoulli keep-mask divided by
        # the survival rate.
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
        x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    # Give all three modules the same (randomised) affine parameters.
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    # Only set when no residual tensor is passed.
    # NOTE(review): presumably the op otherwise takes the residual dtype from x1 - confirm.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
                                     model.epsilon, rowscale=rowscale, layerscale=colscale,
                                     residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
    assert out.dtype == input_dtype
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
    # Rebuild the pre-norm residual with the mask the fused op actually used.
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    # Same upstream gradient for all three graphs.
    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    assert ((model.weight.grad - model_ref.weight.grad).abs().max()
            <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5)
    if not is_rms_norm:
        assert ((model.bias.grad - model_ref.bias.grad).abs().max()
                <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5)
    if has_colscale:
        assert ((colscale.grad - colscale_ref.grad).abs().max()
                <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4)


@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    """Check the forward pass of ``DropoutAddLayerNorm`` in eval mode.

    The module is built with ``p=0.37`` but switched to ``eval()``; the
    references apply no dropout at all (``residual = x0 + x1``). The fused
    output passes when its error w.r.t. the float32 reference is at most
    4x the error of the plain low-precision PyTorch version, plus 1e-4.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip('fp16 weights with bf16 inputs are not supported')
    device = 'cuda'
    dropout_p = 0.37
    # Fixed seed so the inputs are the same on every run.
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    # *_pt tensors feed the plain PyTorch path, *_ref the float32 reference and
    # the unsuffixed ones the fused module. All start from identical values.
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    x1 = x1_pt.detach().clone().requires_grad_()
    x1_ref = x1_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    # Give all three modules the same (randomised) affine parameters.
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out = model(x0, x1)
    residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + x1_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4


@pytest.mark.parametrize('is_rms_norm', [False, True])
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_rowscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                             dropout_p, has_residual, has_rowscale, has_colscale,
                                             is_rms_norm):
    """Check forward and backward of the fused op with ``prenorm=True`` in training mode.

    With ``prenorm=True`` the op returns both the normalised output and the
    pre-norm residual. Each is compared against a float32 reference, with the
    error of a plain low-precision PyTorch implementation as the yardstick.
    The backward pass goes through ``out * sigmoid(residual)`` so that both
    returned tensors contribute to the gradients.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip('fp16 weights with bf16 inputs are not supported')
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip("Apex's FusedRMSNorm is needed as the RMSNorm reference")
    if is_rms_norm:
        layer_norm_cls = FusedRMSNorm
        our_layer_norm_cls = DropoutAddRMSNorm
        our_layer_norm_func = dropout_add_rms_norm
    else:
        layer_norm_cls = torch.nn.LayerNorm
        our_layer_norm_cls = DropoutAddLayerNorm
        our_layer_norm_func = dropout_add_layer_norm
    device = 'cuda'
    # Fixed seed so the inputs are the same on every run.
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    # *_pt tensors feed the plain PyTorch path, *_ref the float32 reference and
    # the unsuffixed ones the fused op. All three start from identical values.
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        # One scale per hidden dimension.
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_rowscale:
        # One scale per (batch, position) row: a Bernoulli keep-mask divided by
        # the survival rate.
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
        x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(hidden_size, prenorm=True, p=dropout_p, device=device,
                               dtype=weight_dtype)
    # Give all three modules the same (randomised) affine parameters.
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    # Only set when no residual tensor is passed.
    # NOTE(review): presumably the op otherwise takes the residual dtype from x1 - confirm.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
                                               model.epsilon, rowscale=rowscale,
                                               layerscale=colscale, prenorm=True,
                                               residual_in_fp32=residual_in_fp32,
                                               return_dropout_mask=True)
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
    # Rebuild the pre-norm residual with the mask the fused op actually used.
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

    # Same upstream gradient for all three graphs. torch.sigmoid is used rather
    # than F.sigmoid, which older PyTorch releases flag as deprecated.
    g = torch.randn_like(out) / batch_size
    (out_pt * torch.sigmoid(residual_pt)).backward(g)
    (out * torch.sigmoid(residual)).backward(g)
    (out_ref * torch.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    assert ((model.weight.grad - model_ref.weight.grad).abs().max()
            <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4)
    if not is_rms_norm:
        assert ((model.bias.grad - model_ref.bias.grad).abs().max()
                <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4)
    if has_colscale:
        assert ((colscale.grad - colscale_ref.grad).abs().max()
                <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4)


@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    """Check the forward pass of ``DropoutAddLayerNorm(prenorm=True)`` in eval mode.

    The module is built with ``p=0.37`` but switched to ``eval()``; the
    references apply no dropout at all (``residual = x0 + x1``). Both values
    returned by the module, the normalised output and the residual, are
    compared against float32 references, with the error of the plain
    low-precision PyTorch version as the yardstick.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip('fp16 weights with bf16 inputs are not supported')
    device = 'cuda'
    dropout_p = 0.37
    # Fixed seed so the inputs are the same on every run.
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    # *_pt tensors feed the plain PyTorch path, *_ref the float32 reference and
    # the unsuffixed ones the fused module. All start from identical values.
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    x1 = x1_pt.detach().clone().requires_grad_()
    x1_ref = x1_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    # Give all three modules the same (randomised) affine parameters.
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out, residual = model(x0, x1)
    residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + x1_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4


@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
def test_dropout_layer_norm_subset_training(
        hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p,
        has_residual, has_colscale):
    """Check forward and backward of ``dropout_add_layer_norm_subset`` (``prenorm=False``).

    Whole batch entries are dropped independently on the input side (``x0``)
    and on the output side (``out``). The fused op receives only the kept
    ``x0`` rows plus index maps; the references work on the full batch with the
    dropped entries zeroed, and are then indexed with the output mask. The
    fused result passes when its error w.r.t. the float32 reference is at most
    a small multiple of the error of the plain low-precision PyTorch version.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip('fp16 weights with bf16 inputs are not supported')
    device = 'cuda'
    # Fixed seed so the masks and inputs are the same on every run.
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        """Return (per-batch keep mask, number of kept rows, per-row index map).

        The index map has shape (batch_size, seqlen) and dtype int32: kept rows
        hold their 1-based running index among the kept rows, dropped rows 0.
        """
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0,
                              dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
        return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)

    # The kept-row count is only needed for the output side.
    x0_mask_batch, _, x0_subset = generate_droppath_masks(batch_size, seqlen,
                                                          drop_path_rate, device)
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen,
                                                                      drop_path_rate, device)

    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    # The fused op only sees the batch entries that survive the input mask.
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        # One scale per hidden dimension.
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None

    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref

    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, prenorm=False, p=dropout_p, device=device,
                                dtype=weight_dtype)
    # Give all three modules the same (randomised) affine parameters.
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)

    # Only set when no residual tensor is passed.
    # NOTE(review): presumably the op otherwise takes the residual dtype from x1 - confirm.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = dropout_add_layer_norm_subset(
        x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
        x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
        out_numrows=out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True)
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')

    # Reference drop-path: zero the dropped batch entries, rescale the rest.
    x0_dropped = repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size)
    x0_scaled_pt = x0_scaled_pt.masked_fill(x0_dropped, 0) * drop_path_scale
    x0_scaled_ref = x0_scaled_ref.masked_fill(x0_dropped, 0) * drop_path_scale
    # dmask only covers the kept x0 rows; scatter it back to the full batch.
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    # Same upstream gradient for all three graphs.
    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    # x0.grad exists only for the kept batch entries, so compare on that subset.
    assert ((x0.grad - x0_ref.grad[x0_mask_batch]).abs().max()
            <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4)
    if has_residual:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    assert ((model.weight.grad - model_ref.weight.grad).abs().max()
            <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4)
    assert ((model.bias.grad - model_ref.bias.grad).abs().max()
            <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4)
    if has_colscale:
        assert ((colscale.grad - colscale_ref.grad).abs().max()
                <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4)


@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
def test_dropout_layer_norm_subset_prenorm_training(
        hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p,
        has_residual, has_colscale):
    """Check forward and backward of ``dropout_add_layer_norm_subset`` with ``prenorm=True``.

    Whole batch entries are dropped independently on the input side (``x0``)
    and on the output side (``out``). With ``prenorm=True`` the op also returns
    the pre-norm residual, which is compared over the full batch. The backward
    pass goes through ``out * sigmoid(residual[out_mask]) + mean(residual)`` so
    that both returned tensors contribute to the gradients.
    """
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip('fp16 weights with bf16 inputs are not supported')
    device = 'cuda'
    # Fixed seed so the masks and inputs are the same on every run.
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        """Return (per-batch keep mask, number of kept rows, per-row index map).

        The index map has shape (batch_size, seqlen) and dtype int32: kept rows
        hold their 1-based running index among the kept rows, dropped rows 0.
        """
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0,
                              dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
        return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)

    # The kept-row count is only needed for the output side.
    x0_mask_batch, _, x0_subset = generate_droppath_masks(batch_size, seqlen,
                                                          drop_path_rate, device)
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen,
                                                                      drop_path_rate, device)

    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    # The fused op only sees the batch entries that survive the input mask.
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        # One scale per hidden dimension.
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None

    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref

    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                dtype=weight_dtype)
    # Give all three modules the same (randomised) affine parameters.
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)

    # Only set when no residual tensor is passed.
    # NOTE(review): presumably the op otherwise takes the residual dtype from x1 - confirm.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = dropout_add_layer_norm_subset(
        x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
        x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
        out_numrows=out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True)
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')

    # Reference drop-path: zero the dropped batch entries, rescale the rest.
    x0_dropped = repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size)
    x0_scaled_pt = x0_scaled_pt.masked_fill(x0_dropped, 0) * drop_path_scale
    x0_scaled_ref = x0_scaled_ref.masked_fill(x0_dropped, 0) * drop_path_scale
    # dmask only covers the kept x0 rows; scatter it back to the full batch.
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

    # Same upstream gradient for all three graphs. torch.sigmoid is used rather
    # than F.sigmoid, which older PyTorch releases flag as deprecated.
    g = torch.randn_like(out) / batch_size
    (out_pt * torch.sigmoid(residual_pt[out_mask_batch])
     + residual_pt.mean(0, keepdim=True)).backward(g)
    (out * torch.sigmoid(residual[out_mask_batch])
     + residual.mean(0, keepdim=True)).backward(g)
    (out_ref * torch.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype))
     + residual_ref.mean(0, keepdim=True)).backward(g)
    # x0.grad exists only for the kept batch entries, so compare on that subset.
    assert ((x0.grad - x0_ref.grad[x0_mask_batch]).abs().max()
            <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4)
    if has_residual:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    assert ((model.weight.grad - model_ref.weight.grad).abs().max()
            <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4)
    assert ((model.bias.grad - model_ref.bias.grad).abs().max()
            <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4)
    if has_colscale:
        assert ((colscale.grad - colscale_ref.grad).abs().max()
                <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4)