# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from dataclasses import dataclass
from typing import Optional
from contextlib import nullcontext

import torch
import pytest

from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    LayerNormMLP,
    TransformerLayer,
    RMSNorm,
    LayerNorm,
    get_cpu_offload_context,
)
from transformer_engine.common import recipe

# Only run FP8 tests when FP8 is supported on the current device.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


def custom_amax_to_scale(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: torch.Tensor,
    recipe: recipe.DelayedScaling,
) -> torch.Tensor:
    """Custom func to test recipe."""
    sf = fp8_max / amax
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)

    return sf


def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    """Custom func to test recipe."""
    return torch.min(amax_history, dim=0).values

@dataclass
class ModelConfig:
    """Transformer model configuration"""

    num_layers: int
    seq_len: int
    batch_size: int
    hidden_size: int
    num_attention_heads: int
    kv_channels: Optional[int] = None

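    # FP8 GEMMs place divisibility requirements on the tensor dimensions, so a
    # config is treated as FP8-friendly here only when both the flattened token
    # count (seq_len * batch_size) and the hidden size are multiples of 16.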
    def is_fp8_supported(self):
        if self.seq_len * self.batch_size % 16:
            return False
        if self.hidden_size % 16:
            return False
        return True

model_configs = {
    "126m": ModelConfig(12, 2048, 2, 768, 12),
    "small": ModelConfig(2, 32, 2, 64, 2),
    "weird": ModelConfig(2, 37, 3, 69, 3),
}

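# FP8 recipes swept by the tests below. The leading positional arguments "0, 1"
# passed to recipe.DelayedScaling are assumed to correspond to margin and interval
# in the version of the recipe API these tests target.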
fp8_recipes = [
    None, # Handles non-FP8 case
    recipe.DelayedScaling(0, 1, recipe.Format.E4M3),
    recipe.DelayedScaling(0, 1, recipe.Format.HYBRID),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, override_linear_precision=(False, False, True)
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="most_recent"
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="max"
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo=custom_amax_compute,
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        scaling_factor_compute_algo=custom_amax_to_scale,
    ),
]

param_types = [torch.float32, torch.float16]
if torch.cuda.is_bf16_supported():
    param_types.append(torch.bfloat16)

all_boolean = [True, False]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]

def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for capture.
    static_input = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size,
        device='cuda', dtype=dtype, requires_grad=True,
    )
    static_target = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size,
        device='cuda', dtype=dtype,
    )

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    use_fp8 = fp8_recipe is not None
    if skip_wgrad:
        _disable_wgrads(block)

    # Pre graph capture warmup in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            optimizer.zero_grad(set_to_none=True)
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
                out = block(static_input)
            loss = loss_fn(out, static_target)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture.
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            static_output = block(static_input)
        static_loss = loss_fn(static_output, static_target)
        static_loss.backward()
        optimizer.step()

    # Fills the graph's input memory with new data to compute on
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    g.replay()

    torch.cuda.synchronize()


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=torch.float32, requires_grad=True
    ).cuda()
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

    if skip_wgrad:
        _disable_wgrads(block)

    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            p.main_grad = torch.zeros_like(p)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."


def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    if skip_wgrad:
        _disable_wgrads(block)

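    # When offloading is requested, get_cpu_offload_context() (as exercised here)
    # returns a context manager under which the block's activations are offloaded to
    # CPU, plus a synchronization callable applied to the output below; with
    # offloading disabled, both collapse to no-ops.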
    if cpu_offload:
        offload_context, sync_function = get_cpu_offload_context(enabled=True)
    else:
        offload_context = nullcontext()
        sync_function = lambda x: x

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
        te_out = block(te_inp_hidden_states)
    te_out = sync_function(te_out)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    te_inp_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
    enc_dec_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            encoder_output=te_inp_hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=not skip_dgrad
    ).cuda()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, requires_grad=True
    ).cuda()
    te_inp.retain_grad()

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    config = model_configs[model]
    module = RMSNorm if normalization == "RMSNorm" else LayerNorm

    block = (
        module(config.hidden_size)
        .to(dtype=torch.float32)
        .cuda()
    )
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad,
                                 zero_centered_gamma, skip_dgrad,
                                 normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = (
        LayerNormLinear(
            config.hidden_size,
            config.hidden_size * 3,
            init_method=init_method,
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        Linear(
            config.hidden_size, config.hidden_size, init_method=output_layer_init_method
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
                              zero_centered_gamma, skip_dgrad, activation,
                              normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        LayerNormMLP(
            config.hidden_size,
            4 * config.hidden_size,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            zero_centered_gamma=zero_centered_gamma,
            activation=activation,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
                    zero_centered_gamma, bias, activation,
                    normalization, parallel_attention_mlp,
                    cpu_offload):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            bias=bias,
            activation=activation,
            normalization=normalization,
            parallel_attention_mlp=parallel_attention_mlp,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)


def test_sanity_gpt_126m():
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            0,
            1,
            recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=True,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
        cpu_offload=False,
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                     normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=True,
            output_layernorm=True,
            zero_centered_gamma=zero_centered_gamma,
            self_attn_mask_type="padding",
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    fp8_recipe = recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                   normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="decoder",
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    fp8_recipe = recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
        )
        .to(dtype=torch.float32)
        .cuda()
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            drop_path_rate=1.0,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            fuse_qkv_params=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            fuse_qkv_params=True,
            fuse_wgrad_accumulation=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                        normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            fuse_qkv_params=True,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)

def test_model_multiple_cast():
    a = torch.zeros((16, 16)).cuda()
    m = Linear(16, 32)

    y = m(a)
    assert y.dtype == torch.float32

    m.half()
    a = a.half()

    y2 = m(a)
    assert y2.dtype == torch.float16