test_sanity.py 25.6 KB
Newer Older
1
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch
import pytest

8
from transformer_engine.pytorch.fp8 import fp8_autocast, is_fp8_available
Przemek Tredak's avatar
Przemek Tredak committed
9
10
11
12
13
14
15
16
17
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    LayerNormMLP,
    TransformerLayer,
18
19
    RMSNorm,
    LayerNorm,
Przemek Tredak's avatar
Przemek Tredak committed
20
21
22
)
from transformer_engine.common import recipe

23
# FP8 tests are skipped unless FP8 execution is available (e.g. on H100);
# the returned reason string is used as the pytest skip message.
fp8_available, reason_for_no_fp8 = is_fp8_available()
25

Przemek Tredak's avatar
Przemek Tredak committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62

def custom_amax_to_scale(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: torch.Tensor,
    recipe: recipe.DelayedScaling,
) -> torch.Tensor:
    """Custom scaling-factor callable used to exercise a user-supplied recipe hook.

    Returns ``fp8_max / amax`` wherever ``amax`` is strictly positive and finite;
    everywhere else the previous ``scale`` is kept unchanged. The ``recipe``
    argument is accepted for signature compatibility but not used.
    """
    candidate = fp8_max / amax
    # Fall back to the old scale for zero, negative, inf or NaN amax values.
    usable = (amax > 0.0) & torch.isfinite(amax)
    return torch.where(usable, candidate, scale)


def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    """Custom amax-reduction callable used to exercise a user-supplied recipe hook.

    Reduces ``amax_history`` along dim 0 by taking the element-wise minimum.
    """
    per_slot_min, _argmin = amax_history.min(dim=0)
    return per_slot_min


class ModelConfig:
    """Plain bundle of model hyper-parameters consumed by the sanity tests."""

    def __init__(
        self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
    ):
        # Hidden width of the inputs, and the normalization epsilon.
        self.hidden_size, self.eps = hidden_size, eps
        # Attention head count, and per-head channels (passed as ``kv_channels``).
        self.num_attention_heads, self.embed = num_attention_heads, embed
        # Layer count (used to scale output-layer init) and input sequence length.
        self.num_layers, self.seq_len = num_layers, seq_len


# Model sizes the sanity tests are parametrized over, keyed by a short label.
model_configs = {
    "126m": ModelConfig(
        hidden_size=768,
        eps=1e-5,
        num_attention_heads=12,
        embed=64,
        num_layers=12,
        seq_len=2048,
    ),
}

# Recipes every FP8-aware test is parametrized over. ``None`` runs the same
# test with FP8 disabled; the remaining entries vary format, amax history
# handling, and (last two) plug in the custom callables defined above.
fp8_recipes = [
    None,  # Handles non-FP8 case
    recipe.DelayedScaling(0, 1, recipe.Format.E4M3),
    recipe.DelayedScaling(0, 1, recipe.Format.HYBRID),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, override_linear_precision=(False, False, True)
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="most_recent"
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="max"
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo=custom_amax_compute,
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        scaling_factor_compute_algo=custom_amax_to_scale,
    ),
]

91
92
93
param_types = [torch.float32, torch.float16]
if torch.cuda.is_bf16_supported():
    param_types.append(torch.bfloat16)
Przemek Tredak's avatar
Przemek Tredak committed
94
95
96

batch_sizes = [1, 2]

97
all_boolean = [True, False]
Przemek Tredak's avatar
Przemek Tredak committed
98

99
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
100
all_normalizations = ["LayerNorm", "RMSNorm"]
schetlur-nv's avatar
schetlur-nv committed
101
102
103
104
105
106

def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
def _test_sanity_e2e_cuda_graph(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    """Capture one full training step of ``block`` in a CUDA graph and replay it.

    The block first runs three warm-up steps on a side stream. It then captures
    forward (under ``fp8_autocast`` when ``fp8_recipe`` is given), MSE loss,
    backward and an SGD step. Finally it refills the static input/target buffers
    with new data and replays the graph once.
    """
    criterion = torch.nn.MSELoss()
    sgd = torch.optim.SGD(block.parameters(), lr=0.1)
    shape = (config.seq_len, bs, config.hidden_size)

    # Static buffers: the captured graph always reads from these exact tensors.
    graph_inp = torch.randn(*shape, device='cuda', dtype=dtype, requires_grad=True)
    graph_tgt = torch.randn(*shape, device='cuda', dtype=dtype)
    fresh_inp = torch.rand_like(graph_inp)
    fresh_tgt = torch.rand_like(graph_tgt)

    fp8_enabled = fp8_recipe is not None
    if skip_wgrad:
        _disable_wgrads(block)

    # Pre-capture warm-up on a separate stream.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _warmup_iter in range(3):
            sgd.zero_grad(set_to_none=True)
            with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
                warm_out = block(graph_inp)
            warm_loss = criterion(warm_out, graph_tgt)
            warm_loss.backward()
            sgd.step()
    torch.cuda.current_stream().wait_stream(side_stream)

    # Record a single training step into the graph.
    graph = torch.cuda.CUDAGraph()
    sgd.zero_grad(set_to_none=True)
    with torch.cuda.graph(graph):
        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
            captured_out = block(graph_inp)
        captured_loss = criterion(captured_out, graph_tgt)
        captured_loss.backward()
        sgd.step()

    # Overwrite the static buffers in place, then re-run the captured step.
    with torch.no_grad():
        graph_inp.copy_(fresh_inp)
        graph_tgt.copy_(fresh_tgt)
    graph.replay()

    torch.cuda.synchronize()


schetlur-nv's avatar
schetlur-nv committed
155
def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of ``block`` under ``torch.autocast`` and check dtypes.

    The input is FP32. Under autocast to ``dtype`` (and ``fp8_autocast`` when
    ``fp8_recipe`` is given), the output must have ``dtype``. The input gradient
    and every trainable parameter's gradient must remain FP32.
    """
    # Allocated directly on the GPU so it is a leaf tensor and backward populates
    # its ``.grad`` without needing ``retain_grad()``.
    te_inp_hidden_states = torch.randn(
        config.seq_len,
        bs,
        config.hidden_size,
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    # NOTE(review): torch.rand samples [0, 1), so ``.bool()`` is True almost
    # everywhere; confirm whether a random (e.g. ``> 0.5``) mask was intended.
    te_inp_attn_mask = torch.rand(
        (1, 1, config.seq_len, config.seq_len), device="cuda"
    ).bool()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            te_out = block(te_inp_hidden_states, te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    """Check that weight gradients are accumulated into ``main_grad`` instead of ``grad``.

    Selected parameters are trainable parameters whose name contains
    ``"weight"``, excluding ``"layer_norm_weight"``. Each selected parameter gets
    a zero ``main_grad`` buffer before one forward/backward pass. Afterwards each
    must have ``grad is None`` and a non-zero ``main_grad``. (The caller in this
    file builds the block with ``fuse_wgrad_accumulation=True``.)
    """

    def _uses_main_grad(name, p):
        # Same selection before and after backward; LayerNorm weights are excluded.
        return "layer_norm_weight" not in name and "weight" in name and p.requires_grad

    te_inp_hidden_states = torch.randn(
        config.seq_len,
        bs,
        config.hidden_size,
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    # NOTE(review): torch.rand samples [0, 1), so ``.bool()`` is True almost
    # everywhere; confirm whether a random (e.g. ``> 0.5``) mask was intended.
    te_inp_attn_mask = torch.rand(
        (1, 1, config.seq_len, config.seq_len), device="cuda"
    ).bool()

    if skip_wgrad:
        _disable_wgrads(block)

    for name, p in block.named_parameters():
        if _uses_main_grad(name, p):
            p.main_grad = torch.zeros_like(p)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    for name, p in block.named_parameters():
        if _uses_main_grad(name, p):
            assert (
                p.grad is None and torch.count_nonzero(p.main_grad) > 0
            ), "Gradient not accumulated."

Przemek Tredak's avatar
Przemek Tredak committed
233

schetlur-nv's avatar
schetlur-nv committed
234
def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test one forward/backward pass of ``block`` with hidden states and a mask.

    FP8 is enabled through ``fp8_autocast`` when ``fp8_recipe`` is given. Weight
    gradients are disabled when ``skip_wgrad`` is set. The test only checks that
    the pass completes.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len,
        bs,
        config.hidden_size,
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    # NOTE(review): torch.rand samples [0, 1), so ``.bool()`` is True almost
    # everywhere; confirm whether a random (e.g. ``> 0.5``) mask was intended.
    te_inp_attn_mask = torch.rand(
        (1, 1, config.seq_len, config.seq_len), device="cuda"
    ).bool()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


schetlur-nv's avatar
schetlur-nv committed
262
def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test one forward/backward pass of a decoder ``block`` with encoder output.

    The same hidden-state tensor is passed both as the block input and as
    ``encoder_output``. FP8 is enabled through ``fp8_autocast`` when
    ``fp8_recipe`` is given.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len,
        bs,
        config.hidden_size,
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    # NOTE(review): torch.rand samples [0, 1), so ``.bool()`` is True almost
    # everywhere; confirm whether a random (e.g. ``> 0.5``) mask was intended.
    te_inp_attn_mask = torch.rand(
        (1, 1, config.seq_len, config.seq_len), device="cuda"
    ).bool()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(
            te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


292
293
294
295
def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
    """Smoke-test one forward/backward pass of a single-input module ``block``.

    ``skip_dgrad`` makes the input not require grad; ``skip_wgrad`` freezes all
    parameters. When both are set, nothing would require grad, so the test is
    skipped. Tuple outputs are reduced to their first element before the loss.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        config.seq_len,
        bs,
        config.hidden_size,
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
def _test_sanity_normalization_amp(block, bs, dtype, config, skip_wgrad, skip_dgrad):
    """Run a normalization ``block`` under ``torch.autocast`` and check AMP dtypes.

    The input is FP32. The output must have ``dtype``. The input gradient (when
    dgrad is exercised) and every trainable parameter's gradient must remain
    FP32. ``skip_dgrad`` makes the input not require grad, and ``skip_wgrad``
    freezes the block's parameters. When both are set, the test is skipped.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    # Allocated directly on the GPU as a leaf, so ``.grad`` is populated by
    # backward without ``retain_grad()``. It only requires grad when dgrad is tested.
    te_inp = torch.randn(
        config.seq_len,
        bs,
        config.hidden_size,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    if not skip_dgrad:
        assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, bs, model, skip_wgrad, skip_dgrad, normalization):
    """AMP sanity check for a standalone FP32 LayerNorm / RMSNorm module."""
    cfg = model_configs[model]
    if normalization == "RMSNorm":
        norm_cls = RMSNorm
    else:
        norm_cls = LayerNorm

    norm_block = norm_cls(cfg.hidden_size, eps=cfg.eps)
    norm_block = norm_block.to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(norm_block, bs, dtype, cfg, skip_wgrad, skip_dgrad)


Przemek Tredak's avatar
Przemek Tredak committed
357
358
359
360
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad,
                                 zero_centered_gamma, skip_dgrad,
                                 normalization):
    """Smoke-test a LayerNormLinear mapping hidden_size -> 3 * hidden_size."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = (
        LayerNormLinear(
            config.hidden_size,
            config.hidden_size * 3,
            eps=config.eps,
            init_method=init_method,
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
Przemek Tredak's avatar
Przemek Tredak committed
392
393
394
395
396
397


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad, skip_dgrad):
    """Smoke-test a square Linear layer with scaled-normal initialization."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        Linear(
            config.hidden_size, config.hidden_size, init_method=output_layer_init_method
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
Przemek Tredak's avatar
Przemek Tredak committed
417
418
419
420
421
422


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad,
                              zero_centered_gamma, skip_dgrad, activation,
                              normalization):
    """Smoke-test a LayerNormMLP with a 4 * hidden_size FFN across activations."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        LayerNormMLP(
            config.hidden_size,
            4 * config.hidden_size,
            eps=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            zero_centered_gamma=zero_centered_gamma,
            activation=activation,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
Przemek Tredak's avatar
Przemek Tredak committed
458
459
460
461
462
463


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
                    zero_centered_gamma, bias, activation,
                    normalization):
    """Smoke-test a GPT-style TransformerLayer.

    The layer uses no post-LayerNorm residual and no output LayerNorm.
    """
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            bias=bias,
            activation=activation,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
507
508
509
510
511
512


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                     normalization):
    """Smoke-test a BERT-style TransformerLayer.

    The layer uses a post-LayerNorm residual and an output LayerNorm.
    """
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=True,
            output_layernorm=True,
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
551
552
553
554
555
556


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                   normalization):
    """Smoke-test a decoder TransformerLayer (``layer_type="decoder"``) with encoder output."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="decoder",
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
596
597
598
599
600
601


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
    """Run an FP32 TransformerLayer under ``torch.autocast`` to ``dtype`` and check AMP dtypes."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
        )
        # Parameters stay FP32; autocast supplies the lower-precision compute.
        .to(dtype=torch.float32)
        .cuda()
    )

    _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
630
631
632
633
634
635


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
    """Smoke-test a TransformerLayer built with ``drop_path_rate=1.0``."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            drop_path_rate=1.0,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
667
668
669
670
671
672


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
    """Smoke-test a TransformerLayer built with ``fuse_qkv_params=True``."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            fuse_qkv_params=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
704
705
706
707
708
709
710
711
712
713


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
    """Check ``main_grad`` accumulation on a TransformerLayer.

    The layer is built with ``fuse_wgrad_accumulation=True`` and fused QKV params.
    """
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            fuse_qkv_params=True,
            fuse_wgrad_accumulation=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_recipe, skip_wgrad)
744
745
746
747
748
749
750
751


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                        normalization):
    """Capture and replay a training step of a GPT-style TransformerLayer in a CUDA graph."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            fuse_qkv_params=True,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_cuda_graph(block, bs, dtype, config, fp8_recipe, skip_wgrad)