# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from contextlib import nullcontext
from dataclasses import dataclass
from typing import Optional

import pytest
import torch

from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
    is_bf16_compatible,
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    LayerNormMLP,
    TransformerLayer,
    RMSNorm,
    LayerNorm,
    get_cpu_offload_context,
)
from transformer_engine.common import recipe

# FP8 tests only run on GPUs that support FP8 (e.g. H100). Support is queried
# once here; tests that need FP8 skip with ``reason_for_no_fp8`` otherwise.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


def custom_amax_to_scale(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: torch.Tensor,
    recipe: recipe.DelayedScaling,
) -> torch.Tensor:
    """Custom scaling-factor callback used to test ``scaling_factor_compute_algo``.

    Computes ``fp8_max / amax`` element-wise, but keeps the existing ``scale``
    wherever ``amax`` is not strictly positive or is not finite (inf/NaN).
    ``recipe`` is part of the callback signature and is not used here.
    """
    # A single mask covers both fallback conditions: non-positive and non-finite amax.
    usable = (amax > 0.0) & torch.isfinite(amax)
    return torch.where(usable, fp8_max / amax, scale)


def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    """Custom amax reduction used to test recipe: element-wise minimum over dim 0."""
    smallest = amax_history.min(dim=0)
    return smallest.values


@dataclass
class ModelConfig:
    """Transformer model configuration"""

    num_layers: int
    seq_len: int
    batch_size: int
    hidden_size: int
    num_attention_heads: int
    # Optional; passed through as ``kv_channels`` when building ``TransformerLayer``.
    kv_channels: Optional[int] = None

    def is_fp8_supported(self):
        """Return True when this config's sizes are usable for the FP8 test variants.

        Both the token count (``seq_len * batch_size``) and ``hidden_size``
        must be multiples of 16; tests skip their FP8 variants otherwise.
        """
        if (self.seq_len * self.batch_size) % 16:
            return False
        if self.hidden_size % 16:
            return False
        return True


# Named model sizes. Positional fields follow ModelConfig's order:
# (num_layers, seq_len, batch_size, hidden_size, num_attention_heads).
# "weird" uses sizes that are not multiples of 16, so is_fp8_supported() is
# False for it and its FP8 variants are skipped.
model_configs = {
    "126m": ModelConfig(12, 2048, 2, 768, 12),
    "small": ModelConfig(2, 32, 2, 64, 2),
    "weird": ModelConfig(2, 37, 3, 69, 3),
}

# FP8 recipes each test is parametrized over. ``None`` means "run without FP8";
# the remaining entries vary format, precision overrides and the amax /
# scaling-factor algorithms (including the custom callbacks defined above).
fp8_recipes = [
    None,  # Handles non-FP8 case
    recipe.DelayedScaling(0, 1, recipe.Format.E4M3),
    recipe.DelayedScaling(0, 1, recipe.Format.HYBRID),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, override_linear_precision=(False, False, True)
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="most_recent"
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="max"
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo=custom_amax_compute,
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        scaling_factor_compute_algo=custom_amax_to_scale,
    ),
]

# Parameter dtypes exercised by every test; bf16 is added only when supported.
param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

all_boolean = [True, False]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]


def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Call ``FP8GlobalStateManager.reset()`` after every test in this module.

    ``autouse=True`` applies the fixture without tests requesting it; the code
    after ``yield`` runs as teardown once the test body has finished.
    """
    yield
    FP8GlobalStateManager.reset()


def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
    """Capture one training step of ``block`` in a CUDA graph and replay it.

    Steps: three eager warmup iterations on a side stream; capture of forward,
    loss, backward and optimizer step into a ``torch.cuda.CUDAGraph``; then a
    replay after copying fresh data into the static placeholder tensors.
    FP8 is enabled iff ``fp8_recipe`` is not None.
    """
    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for capture; the replayed graph reads these exact tensors.
    shape = (config.seq_len, config.batch_size, config.hidden_size)
    static_input = torch.randn(*shape, device="cuda", dtype=dtype, requires_grad=True)
    static_target = torch.randn(*shape, device="cuda", dtype=dtype)

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    use_fp8 = fp8_recipe is not None
    if skip_wgrad:
        _disable_wgrads(block)

    # Pre graph capture warmup in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            optimizer.zero_grad(set_to_none=True)
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
                out = block(static_input)
            loss = loss_fn(out, static_target)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture.
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
            static_output = block(static_input)
        static_loss = loss_fn(static_output, static_target)
        static_loss.backward()
        optimizer.step()

    # Fills the graph's input memory with new data to compute on
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    g.replay()

    torch.cuda.synchronize()


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of ``block`` under ``torch.autocast(dtype)``.

    The output must come out in ``dtype``. The input gradient and every weight
    gradient must stay float32 (the caller builds ``block`` in float32).
    FP8 is nested inside autocast iff ``fp8_recipe`` is not None.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=torch.float32, requires_grad=True
    ).cuda()
    # ``.cuda()`` returns a non-leaf copy, so explicitly keep its gradient.
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


def _fused_accumulation_weights(block):
    """Yield ``(name, param)`` for trainable non-layer-norm weights of ``block``."""
    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        if "weight" in name and p.requires_grad:
            yield name, p


def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    """Check that ``block`` accumulates weight gradients into ``param.main_grad``.

    A zero ``main_grad`` buffer is attached to each trainable weight (layer-norm
    weights excluded). After one forward/backward pass, every buffer must hold
    at least one non-zero value.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

    if skip_wgrad:
        _disable_wgrads(block)

    for _, p in _fused_accumulation_weights(block):
        p.main_grad = torch.zeros_like(p)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    for _, p in _fused_accumulation_weights(block):
        assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."


def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
    """Run one forward/backward pass of ``block`` on random hidden states.

    When ``cpu_offload`` is set, the forward runs inside the context returned by
    ``get_cpu_offload_context(enabled=True)``, and the output is passed through
    its sync function. Otherwise a no-op context and identity function are used.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    if skip_wgrad:
        _disable_wgrads(block)

    if cpu_offload:
        offload_context, sync_function = get_cpu_offload_context(enabled=True)
    else:
        offload_context = nullcontext()

        def sync_function(x):
            return x

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
        te_out = block(te_inp_hidden_states)
    te_out = sync_function(te_out)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of ``block`` with a random boolean padding mask.

    The mask has shape ``(batch_size, 1, 1, seq_len)``. FP8 is enabled iff
    ``fp8_recipe`` is not None.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    te_inp_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of a decoder ``block`` with cross-attention.

    The random hidden states also serve as ``encoder_output``. A random
    ``(1, 1, seq_len, seq_len)`` self-attention mask and a
    ``(batch_size, 1, 1, seq_len)`` encoder-decoder mask are used.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
    enc_dec_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            encoder_output=te_inp_hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
    """Run one forward/backward pass of a single-input module ``block``.

    ``skip_dgrad`` makes the input not require grad and ``skip_wgrad`` freezes
    the weights; with both set there is nothing to differentiate, so the test is
    skipped. If ``block`` returns a tuple, only its first element feeds the loss.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=not skip_dgrad
    ).cuda()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    """Run one AMP forward/backward pass through a normalization ``block``.

    Under ``torch.autocast(dtype)`` the output must come out in ``dtype`` and
    every produced gradient must stay float32.

    ``skip_wgrad`` freezes the block's parameters, and ``skip_dgrad`` makes the
    input not require a gradient. These flags previously affected only the
    both-set skip below, so most parametrizations ran identically. With both
    set there is nothing to differentiate, so the test is skipped.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, requires_grad=not skip_dgrad
    ).cuda()
    if not skip_dgrad:
        # ``.cuda()`` returns a non-leaf copy; keep its gradient for the dgrad check.
        # (retain_grad() raises on tensors that do not require grad.)
        te_inp.retain_grad()

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    if not skip_dgrad:
        assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    """AMP sanity check for the standalone ``LayerNorm`` / ``RMSNorm`` modules.

    The module is built with float32 parameters and run under autocast to ``dtype``.
    """
    config = model_configs[model]
    module = RMSNorm if normalization == "RMSNorm" else LayerNorm

    block = module(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
):
    """Sanity forward/backward for ``LayerNormLinear`` (hidden -> 3 * hidden)."""
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = (
        LayerNormLinear(
            config.hidden_size,
            config.hidden_size * 3,
            init_method=init_method,
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
    """Sanity forward/backward for a square ``Linear`` (hidden -> hidden)."""
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        Linear(
            config.hidden_size, config.hidden_size, init_method=output_layer_init_method
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
):
    """Sanity forward/backward for ``LayerNormMLP`` with a 4x hidden FFN size."""
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        LayerNormMLP(
            config.hidden_size,
            4 * config.hidden_size,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            zero_centered_gamma=zero_centered_gamma,
            activation=activation,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
    cpu_offload,
):
    """Sanity run of a GPT-style ``TransformerLayer``.

    Built with ``apply_residual_connection_post_layernorm=False`` and
    ``output_layernorm=False``; optionally runs with CPU offloading.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            bias=bias,
            activation=activation,
            normalization=normalization,
            parallel_attention_mlp=parallel_attention_mlp,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)


def test_sanity_gpt_126m():
    """Single full-size ("126m") GPT sanity run; uses FP8 only when it is available."""
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            0,
            1,
            recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=True,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
        cpu_offload=False,
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    """Sanity run of a BERT-style ``TransformerLayer`` with a padding mask.

    Built with ``apply_residual_connection_post_layernorm=True``,
    ``output_layernorm=True`` and ``self_attn_mask_type="padding"``.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=True,
            output_layernorm=True,
            zero_centered_gamma=zero_centered_gamma,
            self_attn_mask_type="padding",
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    """Single full-size ("126m") BERT sanity run.

    Mirrors ``test_sanity_gpt_126m``: FP8 is used only when it is available.
    Without FP8 the run proceeds with ``fp8_recipe=None`` rather than being
    skipped entirely by ``test_sanity_bert``'s availability check.
    """
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            0,
            1,
            recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    """Sanity run of a T5-style decoder ``TransformerLayer`` (``layer_type="decoder"``)."""
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="decoder",
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    """Single full-size ("126m") T5 decoder sanity run.

    Mirrors ``test_sanity_gpt_126m``: FP8 is used only when it is available.
    Without FP8 the run proceeds with ``fp8_recipe=None`` rather than being
    skipped entirely by ``test_sanity_T5``'s availability check.
    """
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            0,
            1,
            recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    """AMP sanity run of a ``TransformerLayer`` kept in float32 under autocast to ``dtype``."""
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    # Parameters stay float32; autocast is what lowers compute to ``dtype``.
    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
        )
        .to(dtype=torch.float32)
        .cuda()
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity run of a ``TransformerLayer`` built with ``drop_path_rate=1.0``."""
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            drop_path_rate=1.0,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload=False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity run of a ``TransformerLayer`` built with ``fuse_qkv_params=True``."""
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            fuse_qkv_params=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload=False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
    """Sanity run of a ``TransformerLayer`` with ``fuse_wgrad_accumulation=True``."""
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            fuse_qkv_params=True,
            fuse_wgrad_accumulation=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    """CUDA-graph capture/replay sanity run of a GPT-style ``TransformerLayer``."""
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            fuse_qkv_params=True,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)


def test_model_multiple_cast():
    """A ``Linear`` re-cast with ``.half()`` must produce float16 output.

    The module is first run with float32 input and checked for a float32
    output. It is then converted with ``.half()`` and fed float16 input.
    """
    a = torch.zeros((16, 16)).cuda()
    m = Linear(16, 32)

    y = m(a)
    assert y.dtype == torch.float32

    m.half()
    a = a.half()

    y2 = m(a)
    assert y2.dtype == torch.float16