# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
6
from dataclasses import dataclass
from typing import Optional
7
from contextlib import nullcontext
8

Przemek Tredak's avatar
Przemek Tredak committed
9
10
11
import torch
import pytest

12
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
Przemek Tredak's avatar
Przemek Tredak committed
13
14
15
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
16
    is_bf16_compatible,
Przemek Tredak's avatar
Przemek Tredak committed
17
18
19
20
21
22
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    LayerNormMLP,
    TransformerLayer,
23
24
    RMSNorm,
    LayerNorm,
25
    get_cpu_offload_context,
Przemek Tredak's avatar
Przemek Tredak committed
26
27
28
)
from transformer_engine.common import recipe

29
# Only run FP8 tests on H100.
# `fp8_available` gates every FP8-parametrized test in this file and
# `reason_for_no_fp8` is passed to `pytest.skip` when FP8 cannot be used.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
31

Przemek Tredak's avatar
Przemek Tredak committed
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50

def custom_amax_to_scale(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: torch.Tensor,
    recipe: recipe.DelayedScaling,
) -> torch.Tensor:
    """Custom scaling-factor function used to exercise the recipe hook.

    Returns `fp8_max / amax` wherever `amax` is both strictly positive and
    finite, and the existing `scale` everywhere else. The `recipe` argument
    is accepted but not read.
    """
    # A single mask covers both fallbacks: zero/negative/NaN amax and
    # infinite amax all keep the previous scale.
    usable = (amax > 0.0) & torch.isfinite(amax)
    return torch.where(usable, fp8_max / amax, scale)


def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    """Custom amax reduction used to exercise the recipe hook.

    Reduces `amax_history` along dimension 0 and returns the minimum values
    (the accompanying indices are discarded).
    """
    minima, _indices = amax_history.min(dim=0)
    return minima

51
@dataclass
class ModelConfig:
    """Transformer model configuration"""

    num_layers: int
    seq_len: int
    batch_size: int
    hidden_size: int
    num_attention_heads: int
    kv_channels: Optional[int] = None

    def is_fp8_supported(self) -> bool:
        """Return True when this configuration can be used with FP8.

        Both `seq_len * batch_size` and `hidden_size` must be multiples
        of 16; otherwise the FP8 tests skip this configuration.
        """
        tokens_aligned = (self.seq_len * self.batch_size) % 16 == 0
        hidden_aligned = self.hidden_size % 16 == 0
        return tokens_aligned and hidden_aligned
Przemek Tredak's avatar
Przemek Tredak committed
68
69

# Named model configurations used by the tests. Positional arguments follow
# the ModelConfig field order:
# (num_layers, seq_len, batch_size, hidden_size, num_attention_heads).
model_configs = {
    "126m": ModelConfig(12, 2048, 2, 768, 12),
    "small": ModelConfig(2, 32, 2, 64, 2),
    # Neither seq_len * batch_size nor hidden_size is a multiple of 16, so
    # is_fp8_supported() is False for this entry.
    "weird": ModelConfig(2, 37, 3, 69, 3),
}

# Recipes the FP8-parametrized tests iterate over. The leading `None` runs
# each test without FP8; the remaining entries vary the format, the linear
# precision override, the amax history settings, and the custom hooks
# defined in this file.
fp8_recipes = [
    None,  # Handles non-FP8 case
    recipe.DelayedScaling(0, 1, recipe.Format.E4M3),
    recipe.DelayedScaling(0, 1, recipe.Format.HYBRID),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, override_linear_precision=(False, False, True)
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="most_recent"
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="max"
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo=custom_amax_compute,
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        scaling_factor_compute_algo=custom_amax_to_scale,
    ),
]

104
# Parameter dtypes to test. bfloat16 is appended last when supported, so
# `param_types[-1]` picks it up in the 126m tests.
param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

# Shared value lists for pytest parametrization.
all_boolean = [True, False]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
schetlur-nv's avatar
schetlur-nv committed
112
113
114
115
116
117

def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


118
def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
    """Warm up, capture one training step of `block` in a CUDA graph, and replay it.

    Args:
        block: Module under test; called as `block(input)`.
        dtype: dtype of the input and target tensors.
        config: Supplies `seq_len`, `batch_size` and `hidden_size` for the
            tensor shape.
        fp8_recipe: Recipe passed to `fp8_autocast`; `None` disables FP8.
        skip_wgrad: When true, all parameters of `block` are frozen first.
    """
    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for capture.
    shape = (config.seq_len, config.batch_size, config.hidden_size)
    static_input = torch.randn(shape, device="cuda", dtype=dtype, requires_grad=True)
    static_target = torch.randn(shape, device="cuda", dtype=dtype)

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    use_fp8 = fp8_recipe is not None
    if skip_wgrad:
        _disable_wgrads(block)

    # Pre graph capture warmup in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            optimizer.zero_grad(set_to_none=True)
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
                out = block(static_input)
            loss = loss_fn(out, static_target)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture.
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            static_output = block(static_input)
        static_loss = loss_fn(static_output, static_target)
        static_loss.backward()
        optimizer.step()

    # Fills the graph's input memory with new data to compute on
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    g.replay()

    torch.cuda.synchronize()


166
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run forward/backward of `block` under `torch.autocast` and check dtypes.

    The input is created in float32 and autocast is enabled with `dtype`.
    After the backward pass the output must have `dtype`, while the input
    gradient and all weight gradients must be float32.

    Args:
        block: Module under test; called with hidden states and
            `attention_mask`.
        dtype: Autocast dtype, also the expected output dtype.
        config: Supplies `seq_len`, `batch_size` and `hidden_size`.
        fp8_recipe: Recipe passed to `fp8_autocast`; `None` disables FP8.
        skip_wgrad: When true, all parameters of `block` are frozen first.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len,
        config.batch_size,
        config.hidden_size,
        dtype=torch.float32,
        requires_grad=True,
    ).cuda()
    # `.cuda()` yields a non-leaf tensor, so its grad must be retained
    # explicitly for the dgrad dtype check below.
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


192
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run forward/backward and check gradients were accumulated into `main_grad`.

    Every trainable parameter whose name contains "weight" (except those
    containing "layer_norm_weight") gets a zero-initialised `main_grad`
    buffer before the pass; afterwards each buffer must hold at least one
    non-zero element.

    Args:
        block: Module under test; called with hidden states and
            `attention_mask`.
        dtype: dtype of the input tensor.
        config: Supplies `seq_len`, `batch_size` and `hidden_size`.
        fp8_recipe: Recipe passed to `fp8_autocast`; `None` disables FP8.
        skip_wgrad: When true, all parameters of `block` are frozen first.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

    if skip_wgrad:
        _disable_wgrads(block)

    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            p.main_grad = torch.zeros_like(p)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."
219

Przemek Tredak's avatar
Przemek Tredak committed
220

221
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
    """Run one forward/backward pass of `block`, optionally with CPU offloading.

    Args:
        block: Module under test; called as `block(hidden_states)`.
        dtype: dtype of the input tensor.
        config: Supplies `seq_len`, `batch_size` and `hidden_size`.
        fp8_recipe: Recipe passed to `fp8_autocast`; `None` disables FP8.
        skip_wgrad: When true, all parameters of `block` are frozen first.
        cpu_offload: When true, the forward pass runs inside the context
            returned by `get_cpu_offload_context(enabled=True)` and the
            output is passed through the returned sync function.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    if skip_wgrad:
        _disable_wgrads(block)

    if cpu_offload:
        offload_context, sync_function = get_cpu_offload_context(enabled=True)
    else:
        offload_context = nullcontext()

        def sync_function(x):
            # No offloading: the output passes through unchanged.
            return x

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
        te_out = block(te_inp_hidden_states)
    te_out = sync_function(te_out)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


244
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of `block` with a random per-batch boolean mask.

    Args:
        block: Module under test; called with hidden states and
            `attention_mask`.
        dtype: dtype of the input tensor.
        config: Supplies `seq_len`, `batch_size` and `hidden_size`.
        fp8_recipe: Recipe passed to `fp8_autocast`; `None` disables FP8.
        skip_wgrad: When true, all parameters of `block` are frozen first.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    # Boolean mask of shape (batch_size, 1, 1, seq_len).
    te_inp_attn_mask = (
        torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


262
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of a decoder `block` with encoder inputs.

    The same hidden-states tensor is passed both as the block input and as
    `encoder_output`.

    Args:
        block: Module under test; called with hidden states,
            `attention_mask`, `encoder_output` and `enc_dec_attn_mask`.
        dtype: dtype of the input tensor.
        config: Supplies `seq_len`, `batch_size` and `hidden_size`.
        fp8_recipe: Recipe passed to `fp8_autocast`; `None` disables FP8.
        skip_wgrad: When true, all parameters of `block` are frozen first.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
    enc_dec_attn_mask = (
        torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            encoder_output=te_inp_hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


285
def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
    """Run one forward/backward pass of a single-input `block`.

    Args:
        block: Module under test; called as `block(input)`. If it returns a
            tuple, only the first element is used for the loss.
        dtype: dtype of the input tensor.
        config: Supplies `seq_len`, `batch_size` and `hidden_size`.
        fp8_recipe: Recipe passed to `fp8_autocast`; `None` disables FP8.
        skip_wgrad: When true, all parameters of `block` are frozen first.
        skip_dgrad: When true, the input is created without `requires_grad`.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        config.seq_len,
        config.batch_size,
        config.hidden_size,
        dtype=dtype,
        requires_grad=not skip_dgrad,
    ).cuda()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


306
def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    """Run a normalization `block` under `torch.autocast` and check dtypes.

    The input is created in the default dtype and autocast is enabled with
    `dtype`. The output must have `dtype`; any gradients that are computed
    (input and/or weights) must be float32.

    Args:
        block: Normalization module under test; called as `block(input)`.
        dtype: Autocast dtype, also the expected output dtype.
        config: Supplies `seq_len`, `batch_size` and `hidden_size`.
        skip_wgrad: When true, all parameters of `block` are frozen first.
        skip_dgrad: When true, the input is created without `requires_grad`
            and the input-gradient check is skipped.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    # Honour the skip flags the same way _test_sanity_common does; previously
    # they were only used for the early skip above, so every parametrized
    # combination exercised the identical code path.
    te_inp = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, requires_grad=not skip_dgrad
    ).cuda()
    if not skip_dgrad:
        # `.cuda()` yields a non-leaf tensor, so its grad must be retained
        # explicitly. retain_grad() would raise on a tensor that does not
        # require grad, hence the guard.
        te_inp.retain_grad()

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    if not skip_dgrad:
        assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    """Sanity-check the standalone LayerNorm / RMSNorm modules under autocast."""
    config = model_configs[model]
    module = RMSNorm if normalization == "RMSNorm" else LayerNorm

    # Parameters stay in float32; `dtype` is only applied through autocast.
    block = module(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
344
345


Przemek Tredak's avatar
Przemek Tredak committed
346
347
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
):
    """Sanity-check a forward/backward pass through LayerNormLinear."""
    config = model_configs[model]

    # FP8 runs need both hardware support and a compatible model config.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = (
        LayerNormLinear(
            config.hidden_size,
            config.hidden_size * 3,
            init_method=init_method,
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
Przemek Tredak's avatar
Przemek Tredak committed
379
380
381
382


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
    """Sanity-check a forward/backward pass through Linear."""
    config = model_configs[model]

    # FP8 runs need both hardware support and a compatible model config.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        Linear(
            config.hidden_size, config.hidden_size, init_method=output_layer_init_method
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
Przemek Tredak's avatar
Przemek Tredak committed
406
407
408
409


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
):
    """Sanity-check a forward/backward pass through LayerNormMLP."""
    config = model_configs[model]

    # FP8 runs need both hardware support and a compatible model config.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        LayerNormMLP(
            config.hidden_size,
            4 * config.hidden_size,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            zero_centered_gamma=zero_centered_gamma,
            activation=activation,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
Przemek Tredak's avatar
Przemek Tredak committed
445
446
447
448


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
    cpu_offload,
):
    """Sanity-check a GPT-style TransformerLayer (no post-layernorm residual, no output layernorm)."""
    config = model_configs[model]

    # FP8 runs need both hardware support and a compatible model config.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            bias=bias,
            activation=activation,
            normalization=normalization,
            parallel_attention_mlp=parallel_attention_mlp,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517


def test_sanity_gpt_126m():
    """Run the GPT sanity test once on the "126m" config.

    Uses the last entry of `param_types` as dtype. An FP8 recipe is used only
    when FP8 is available; otherwise the test runs without FP8.
    """
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            0,
            1,
            recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=True,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
        cpu_offload=False,
    )
Przemek Tredak's avatar
Przemek Tredak committed
520
521
522
523


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    """Sanity-check a BERT-style TransformerLayer (post-layernorm residual, output layernorm, padding mask)."""
    config = model_configs[model]

    # FP8 runs need both hardware support and a compatible model config.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=True,
            output_layernorm=True,
            zero_centered_gamma=zero_centered_gamma,
            self_attn_mask_type="padding",
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    """Run the BERT sanity test once on the "126m" config with an E4M3 recipe.

    Uses the last entry of `param_types` as dtype.
    """
    delayed_scaling = recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=1, amax_compute_algo="most_recent"
    )
    call_kwargs = dict(
        dtype=param_types[-1],
        fp8_recipe=delayed_scaling,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )
    test_sanity_bert(**call_kwargs)
Przemek Tredak's avatar
Przemek Tredak committed
581
582
583
584


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    """Sanity-check a T5-style decoder TransformerLayer (`layer_type="decoder"`)."""
    config = model_configs[model]

    # FP8 runs need both hardware support and a compatible model config.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="decoder",
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    """Run the T5 sanity test once on the "126m" config with an E4M3 recipe.

    Uses the last entry of `param_types` as dtype.
    """
    delayed_scaling = recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=1, amax_compute_algo="most_recent"
    )
    call_kwargs = dict(
        dtype=param_types[-1],
        fp8_recipe=delayed_scaling,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )
    test_sanity_T5(**call_kwargs)
Przemek Tredak's avatar
Przemek Tredak committed
642
643
644
645


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a float32 TransformerLayer run under `torch.autocast` with `dtype`."""
    config = model_configs[model]

    # FP8 runs need both hardware support and a compatible model config.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    # Parameters stay in float32; `dtype` is only applied through autocast.
    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
        )
        .to(dtype=torch.float32)
        .cuda()
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
677
678
679
680


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a TransformerLayer built with `drop_path_rate=1.0`."""
    config = model_configs[model]

    # FP8 runs need both hardware support and a compatible model config.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            drop_path_rate=1.0,
        )
        .to(dtype=dtype)
        .cuda()
    )

    # Last argument is cpu_offload: not exercised here.
    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
Przemek Tredak's avatar
Przemek Tredak committed
715
716
717
718


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a TransformerLayer built with `fuse_qkv_params=True`."""
    config = model_configs[model]

    # FP8 runs need both hardware support and a compatible model config.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            fuse_qkv_params=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    # Last argument is cpu_offload: not exercised here.
    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
753
754
755
756


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
    """Sanity-check a TransformerLayer built with `fuse_wgrad_accumulation=True`."""
    config = model_configs[model]

    # FP8 runs need both hardware support and a compatible model config.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            fuse_qkv_params=True,
            fuse_wgrad_accumulation=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)
794
795
796
797


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    """Sanity-check CUDA graph capture and replay of a GPT-style TransformerLayer."""
    config = model_configs[model]

    # FP8 runs need both hardware support and a compatible model config.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            fuse_qkv_params=True,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)
837
838
839
840
841
842
843
844
845
846
847
848
849

def test_model_multiple_cast():
    """Check that a Linear module's output dtype follows a later `.half()` cast.

    The module is first run with a float32 input, then cast to half
    precision in place and run again with a half-precision input.
    """
    layer = Linear(16, 32)
    inp_fp32 = torch.zeros((16, 16)).cuda()

    out_fp32 = layer(inp_fp32)
    assert out_fp32.dtype == torch.float32

    layer.half()
    inp_fp16 = inp_fp32.half()

    out_fp16 = layer(inp_fp16)
    assert out_fp16.dtype == torch.float16