test_sanity.py 28 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
7
from dataclasses import dataclass
from typing import Optional

Przemek Tredak's avatar
Przemek Tredak committed
8
9
10
import torch
import pytest

11
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
Przemek Tredak's avatar
Przemek Tredak committed
12
13
14
15
16
17
18
19
20
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    LayerNormMLP,
    TransformerLayer,
21
22
    RMSNorm,
    LayerNorm,
Przemek Tredak's avatar
Przemek Tredak committed
23
24
25
)
from transformer_engine.common import recipe

26
# Only run FP8 tests on H100.
# `fp8_available` gates the FP8 test cases; `reason_for_no_fp8` is the message
# passed to `pytest.skip` when FP8 is unavailable.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
28

Przemek Tredak's avatar
Przemek Tredak committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47

def custom_amax_to_scale(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: torch.Tensor,
    recipe: recipe.DelayedScaling,
) -> torch.Tensor:
    """Custom scaling-factor function used to exercise the recipe hook.

    Returns ``fp8_max / amax`` wherever ``amax`` is strictly positive and
    finite, and falls back to the incoming ``scale`` everywhere else.
    The ``recipe`` argument is accepted but not used.
    """
    # An entry is usable only when it is both positive and finite
    # (NaN fails the `> 0.0` comparison, +inf fails `isfinite`).
    usable = (amax > 0.0) & torch.isfinite(amax)
    candidate = fp8_max / amax
    return torch.where(usable, candidate, scale)


def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    """Custom amax reduction used to exercise the recipe hook.

    Reduces ``amax_history`` along dimension 0 by taking the minimum.
    """
    smallest, _indices = amax_history.min(dim=0)
    return smallest

48
@dataclass
class ModelConfig:
    """Transformer model configuration"""

    # Passed to `scaled_init_method_normal` when building the output-layer init.
    num_layers: int
    # First dimension of the random input tensors created by the tests.
    seq_len: int
    # Second dimension of the random input tensors created by the tests.
    batch_size: int
    # Last dimension of the input tensors and the hidden size of the modules.
    hidden_size: int
    # Number of attention heads given to `TransformerLayer`.
    num_attention_heads: int
    # Forwarded as `kv_channels` to `TransformerLayer`; None leaves it unset.
    kv_channels: Optional[int] = None

    def is_fp8_supported(self) -> bool:
        """Return True when this configuration's sizes are usable for the FP8 tests.

        Both ``seq_len * batch_size`` and ``hidden_size`` must be multiples
        of 16; otherwise False is returned.
        """
        if self.seq_len * self.batch_size % 16:
            return False
        if self.hidden_size % 16:
            return False
        return True
Przemek Tredak's avatar
Przemek Tredak committed
65
66

# Named model configurations, looked up by the `model` test parameter.
# Positional fields: num_layers, seq_len, batch_size, hidden_size, num_attention_heads.
# "weird" uses sizes that are not multiples of 16, so `is_fp8_supported()` is False for it.
model_configs = {
    "126m": ModelConfig(12, 2048, 2, 768, 12),
    "small": ModelConfig(2, 32, 2, 64, 2),
    "weird": ModelConfig(2, 37, 3, 69, 3),
}

# Values of the `fp8_recipe` test parameter. `None` runs the module without
# FP8; the remaining entries are DelayedScaling recipes that vary the format,
# the linear-precision override, the amax history settings and the custom
# amax / scaling-factor callbacks defined in this file.
fp8_recipes = [
    None,  # Handles non-FP8 case
    recipe.DelayedScaling(0, 1, recipe.Format.E4M3),
    recipe.DelayedScaling(0, 1, recipe.Format.HYBRID),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, override_linear_precision=(False, False, True)
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="most_recent"
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="max"
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo=custom_amax_compute,
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        scaling_factor_compute_algo=custom_amax_to_scale,
    ),
]

101
102
103
param_types = [torch.float32, torch.float16]
if torch.cuda.is_bf16_supported():
    param_types.append(torch.bfloat16)
Przemek Tredak's avatar
Przemek Tredak committed
104

105
# Both values of a boolean test parameter.
all_boolean = [True, False]

# Names passed as the `activation` argument of the modules under test.
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
# Names passed as the `normalization` argument of the modules under test.
all_normalizations = ["LayerNorm", "RMSNorm"]
schetlur-nv's avatar
schetlur-nv committed
109
110
111
112
113
114

def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


115
def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
    """Capture one training step of `block` in a CUDA graph and replay it once.

    The step is forward (under `fp8_autocast`, enabled when `fp8_recipe` is
    not None), MSE loss, backward and an SGD update. It is warmed up for three
    iterations on a side stream, captured, and then replayed after new data
    has been copied into the static input/target buffers. When `skip_wgrad`
    is set, all parameters of `block` are frozen first.
    """
    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for capture.
    static_input = torch.randn(
        config.seq_len,
        config.batch_size,
        config.hidden_size,
        device='cuda',
        dtype=dtype,
        requires_grad=True,
    )
    static_target = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, device='cuda', dtype=dtype
    )

    # Data copied into the static buffers before the replay.
    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    use_fp8 = fp8_recipe is not None
    if skip_wgrad:
        _disable_wgrads(block)

    # Pre graph capture warmup in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            optimizer.zero_grad(set_to_none=True)
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
                out = block(static_input)
            loss = loss_fn(out, static_target)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture.
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            static_output = block(static_input)
        static_loss = loss_fn(static_output, static_target)
        static_loss.backward()
        optimizer.step()

    # Fills the graph's input memory with new data to compute on
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    g.replay()

    torch.cuda.synchronize()


163
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run `block` forward/backward under `torch.autocast` and check dtypes.

    The input is created in FP32 and the forward pass runs inside
    `torch.autocast(dtype=dtype)` nested with `fp8_autocast` (enabled when
    `fp8_recipe` is not None). Afterwards the output must have `dtype`, while
    the input gradient and every trainable parameter gradient must be FP32.
    When `skip_wgrad` is set, all parameters of `block` are frozen first.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=torch.float32, requires_grad=True
    ).cuda()
    # The tensor returned by `.cuda()` is what gets fed to the block; ask
    # autograd to keep its `.grad` so it can be inspected below.
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


189
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run `block` forward/backward and check that `main_grad` buffers were written.

    Every trainable parameter whose name contains "weight" (excluding names
    containing "layer_norm_weight") gets a zero-initialized `main_grad`
    attribute before the pass; after the backward pass each of those buffers
    must contain at least one non-zero element. When `skip_wgrad` is set, all
    parameters are frozen first, so no buffer is created or checked.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

    if skip_wgrad:
        _disable_wgrads(block)

    # Attach zeroed accumulation buffers to the selected weights.
    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            p.main_grad = torch.zeros_like(p)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    # Same selection as above: each buffer must no longer be all zeros.
    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."
216

Przemek Tredak's avatar
Przemek Tredak committed
217

218
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of `block` on random hidden states.

    The forward pass runs under `fp8_autocast`, enabled when `fp8_recipe` is
    not None. When `skip_wgrad` is set, all parameters of `block` are frozen
    first. Nothing is asserted; the pass only has to complete.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


234
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of `block` with a random boolean attention mask.

    The mask has shape (batch_size, 1, 1, seq_len) and is passed as
    `attention_mask`. The forward pass runs under `fp8_autocast`, enabled when
    `fp8_recipe` is not None. When `skip_wgrad` is set, all parameters of
    `block` are frozen first. Nothing is asserted; the pass only has to
    complete.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    te_inp_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


252
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of `block` with encoder output and two masks.

    The same random hidden states are passed both as the block input and as
    `encoder_output`. `attention_mask` is a random boolean tensor of shape
    (1, 1, seq_len, seq_len) and `enc_dec_attn_mask` one of shape
    (batch_size, 1, 1, seq_len). The forward pass runs under `fp8_autocast`,
    enabled when `fp8_recipe` is not None. When `skip_wgrad` is set, all
    parameters of `block` are frozen first. Nothing is asserted.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
    enc_dec_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            encoder_output=te_inp_hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


275
def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
    """Run one forward/backward pass of `block`, optionally without dgrad or wgrad.

    `skip_dgrad` creates the input with `requires_grad=False`; `skip_wgrad`
    freezes all parameters of `block`. The test is skipped when both are set.
    The forward pass runs under `fp8_autocast`, enabled when `fp8_recipe` is
    not None. If the block returns a tuple, only its first element feeds the
    loss. Nothing is asserted; the pass only has to complete.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=not skip_dgrad
    ).cuda()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


296
def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    """Run a normalization `block` under `torch.autocast` and check dtypes.

    The input is created with the default dtype and the forward pass runs
    inside `torch.autocast(dtype=dtype)`. Afterwards the output must have
    `dtype`, and the gradients that were computed must be FP32.

    `skip_dgrad` creates the input with `requires_grad=False` and skips the
    input-gradient check; `skip_wgrad` freezes all parameters of `block`.
    The test is skipped when both are set.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, requires_grad=not skip_dgrad
    ).cuda()
    if not skip_dgrad:
        # The tensor returned by `.cuda()` is what gets fed to the block; ask
        # autograd to keep its `.grad` so it can be inspected below.
        # (`retain_grad` is only valid on tensors that require grad.)
        te_inp.retain_grad()

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    if not skip_dgrad:
        assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    """Sanity-check RMSNorm / LayerNorm under autocast with FP32 parameters."""
    config = model_configs[model]
    module = RMSNorm if normalization == "RMSNorm" else LayerNorm

    # Parameters stay in FP32; `dtype` is only used as the autocast dtype.
    block = (
        module(config.hidden_size)
        .to(dtype=torch.float32)
        .cuda()
    )
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
334
335


Przemek Tredak's avatar
Przemek Tredak committed
336
337
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad,
                                 zero_centered_gamma, skip_dgrad,
                                 normalization):
    """Sanity-check a forward/backward pass of `LayerNormLinear`.

    Skipped when an FP8 recipe is requested but FP8 is unavailable or the
    model config does not support it, and for RMSNorm combined with
    `zero_centered_gamma`.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = (
        LayerNormLinear(
            config.hidden_size,
            config.hidden_size * 3,
            init_method=init_method,
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
Przemek Tredak's avatar
Przemek Tredak committed
372
373
374
375


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
    """Sanity-check a forward/backward pass of `Linear`.

    Skipped when an FP8 recipe is requested but FP8 is unavailable or the
    model config does not support it.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        Linear(
            config.hidden_size, config.hidden_size, init_method=output_layer_init_method
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
Przemek Tredak's avatar
Przemek Tredak committed
399
400
401
402


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
                              zero_centered_gamma, skip_dgrad, activation,
                              normalization):
    """Sanity-check a forward/backward pass of `LayerNormMLP`.

    Skipped when an FP8 recipe is requested but FP8 is unavailable or the
    model config does not support it, and for RMSNorm combined with
    `zero_centered_gamma`.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        LayerNormMLP(
            config.hidden_size,
            4 * config.hidden_size,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            zero_centered_gamma=zero_centered_gamma,
            activation=activation,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
Przemek Tredak's avatar
Przemek Tredak committed
441
442
443
444


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
                    zero_centered_gamma, bias, activation,
                    normalization, parallel_attention_mlp):
    """Sanity-check a forward/backward pass of a `TransformerLayer` called without a mask.

    The layer is built with `apply_residual_connection_post_layernorm=False`
    and `output_layernorm=False`. Skipped when an FP8 recipe is requested but
    FP8 is unavailable or the model config does not support it, and for
    RMSNorm combined with `zero_centered_gamma`.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            bias=bias,
            activation=activation,
            normalization=normalization,
            parallel_attention_mlp=parallel_attention_mlp,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_gpt_126m():
    """Run the GPT sanity test once on the "126m" model configuration.

    Uses an E4M3 delayed-scaling recipe when FP8 is available and no FP8
    recipe otherwise, with the last entry of `param_types` as dtype.
    """
    fp8_recipe = (
        recipe.DelayedScaling(
            0,
            1,
            recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
        if fp8_available
        else None
    )
    gpt_kwargs = dict(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=True,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
    )
    test_sanity_gpt(**gpt_kwargs)
Przemek Tredak's avatar
Przemek Tredak committed
516
517
518
519


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                     normalization):
    """Sanity-check a forward/backward pass of a `TransformerLayer` with a padding mask.

    The layer is built with `apply_residual_connection_post_layernorm=True`,
    `output_layernorm=True` and `self_attn_mask_type="padding"`. Skipped when
    an FP8 recipe is requested but FP8 is unavailable or the model config does
    not support it, and for RMSNorm combined with `zero_centered_gamma`.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=True,
            output_layernorm=True,
            zero_centered_gamma=zero_centered_gamma,
            self_attn_mask_type="padding",
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    """Run the BERT sanity test once on the "126m" model configuration.

    Always passes an E4M3 delayed-scaling recipe (amax history length 1),
    with the last entry of `param_types` as dtype.
    """
    bert_kwargs = dict(
        fp8_recipe=recipe.DelayedScaling(
            0,
            1,
            recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        ),
        dtype=param_types[-1],
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )
    test_sanity_bert(**bert_kwargs)
Przemek Tredak's avatar
Przemek Tredak committed
580
581
582
583


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                   normalization):
    """Sanity-check a forward/backward pass of a decoder-type `TransformerLayer`.

    The layer is built with `layer_type="decoder"` and run with an encoder
    output and an encoder-decoder mask. Skipped when an FP8 recipe is
    requested but FP8 is unavailable or the model config does not support it,
    and for RMSNorm combined with `zero_centered_gamma`.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="decoder",
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    """Run the T5 sanity test once on the "126m" model configuration.

    Always passes an E4M3 delayed-scaling recipe (amax history length 1),
    with the last entry of `param_types` as dtype.
    """
    t5_kwargs = dict(
        fp8_recipe=recipe.DelayedScaling(
            0,
            1,
            recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        ),
        dtype=param_types[-1],
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )
    test_sanity_T5(**t5_kwargs)
Przemek Tredak's avatar
Przemek Tredak committed
644
645
646
647


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a `TransformerLayer` with FP32 parameters under autocast.

    The layer is kept in FP32; `dtype` is only used as the autocast dtype by
    the helper. Skipped when an FP8 recipe is requested but FP8 is unavailable
    or the model config does not support it.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
        )
        .to(dtype=torch.float32)
        .cuda()
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
679
680
681
682


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a `TransformerLayer` built with `drop_path_rate=1.0`.

    Skipped when an FP8 recipe is requested but FP8 is unavailable or the
    model config does not support it.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            drop_path_rate=1.0,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
717
718
719
720


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a `TransformerLayer` built with `fuse_qkv_params=True`.

    Skipped when an FP8 recipe is requested but FP8 is unavailable or the
    model config does not support it.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            fuse_qkv_params=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
755
756
757
758


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
    """Sanity-check a `TransformerLayer` built with `fuse_wgrad_accumulation=True`.

    The layer also uses `fuse_qkv_params=True`. Skipped when an FP8 recipe is
    requested but FP8 is unavailable or the model config does not support it.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            fuse_qkv_params=True,
            fuse_wgrad_accumulation=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)
796
797
798
799


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                        normalization):
    """Sanity-check CUDA graph capture and replay of a `TransformerLayer` training step.

    The layer is built with `fuse_qkv_params=True`. Skipped when an FP8 recipe
    is requested but FP8 is unavailable or the model config does not support
    it, and for RMSNorm combined with `zero_centered_gamma`.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            fuse_qkv_params=True,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)
842
843
844
845
846
847
848
849
850
851
852
853
854

def test_model_multiple_cast():
    """Check that a `Linear` module's output dtype follows a later `.half()` cast."""
    inp = torch.zeros((16,16)).cuda()
    layer = Linear(16,32)

    # First call: the output is expected to be FP32.
    out_fp32 = layer(inp)
    assert out_fp32.dtype == torch.float32

    # Cast the module, then the input, to FP16 and call again.
    layer.half()
    inp = inp.half()
    out_fp16 = layer(inp)
    assert out_fp16.dtype == torch.float16