test_sanity.py 26.1 KB
Newer Older
1
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
6
7
#
# See LICENSE for license information.

import torch
import pytest

8
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
Przemek Tredak's avatar
Przemek Tredak committed
9
10
11
12
13
14
15
16
17
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    LayerNormMLP,
    TransformerLayer,
18
19
    RMSNorm,
    LayerNorm,
Przemek Tredak's avatar
Przemek Tredak committed
20
21
22
)
from transformer_engine.common import recipe

23
# FP8 tests only run when FP8GlobalStateManager reports FP8 support on this
# machine; ``reason_for_no_fp8`` is used as the pytest skip message otherwise.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
25

Przemek Tredak's avatar
Przemek Tredak committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62

def custom_amax_to_scale(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: torch.Tensor,
    recipe: recipe.DelayedScaling,
) -> torch.Tensor:
    """Custom scaling-factor function used to exercise a user-supplied recipe hook.

    Passed as ``scaling_factor_compute_algo`` in one of the ``fp8_recipes``.

    Args:
        amax: Observed absolute maxima.
        scale: Current scaling factors, used as the fallback value.
        fp8_max: Largest representable FP8 value(s).
        recipe: The recipe in use; not read here (presumably part of the
            calling convention expected by the hook — TODO confirm).

    Returns:
        ``fp8_max / amax`` wherever ``amax`` is both positive and finite, and
        ``scale`` everywhere else (zero, negative, inf or NaN amax).
    """
    # A single mask covers both fallback conditions of the original two-step form.
    usable = (amax > 0.0) & torch.isfinite(amax)
    return torch.where(usable, fp8_max / amax, scale)


def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    """Custom amax reduction used to exercise a user-supplied recipe hook.

    Passed as ``amax_compute_algo`` in one of the ``fp8_recipes``.

    Args:
        amax_history: History of amax values, reduced along dim 0.

    Returns:
        The element-wise minimum of ``amax_history`` taken over dim 0.
    """
    smallest, _ = amax_history.min(dim=0)
    return smallest


class ModelConfig:
    """Plain container for the model hyper-parameters used by the sanity tests.

    Attributes (stored verbatim, no validation):
        hidden_size: Hidden dimension; also the last dim of the test inputs.
        eps: Epsilon passed to the normalization layers.
        num_attention_heads: Number of attention heads for TransformerLayer.
        embed: Passed as ``kv_channels`` to TransformerLayer.
        num_layers: Used to scale the output-layer init method.
        seq_len: Sequence length (first dim of the test inputs).
    """

    def __init__(
        self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
    ):
        self.hidden_size, self.eps = hidden_size, eps
        self.num_attention_heads, self.embed = num_attention_heads, embed
        self.num_layers, self.seq_len = num_layers, seq_len


# Model sizes the sanity tests are parametrized over, keyed by a short name.
model_configs = {
    "126m": ModelConfig(
        hidden_size=768,
        eps=1e-5,
        num_attention_heads=12,
        embed=64,
        num_layers=12,
        seq_len=2048,
    ),
}

# Recipes every FP8-capable test is parametrized over. ``None`` runs the test
# without FP8; the remaining entries vary the FP8 format, the per-GEMM
# precision override, the amax history handling, and plug in the
# custom_amax_compute / custom_amax_to_scale hooks.
fp8_recipes = [
    None,  # Handles non-FP8 case
    recipe.DelayedScaling(0, 1, recipe.Format.E4M3),
    recipe.DelayedScaling(0, 1, recipe.Format.HYBRID),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, override_linear_precision=(False, False, True)
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="most_recent"
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="max"
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo=custom_amax_compute,
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        scaling_factor_compute_algo=custom_amax_to_scale,
    ),
]

91
92
93
param_types = [torch.float32, torch.float16]
if torch.cuda.is_bf16_supported():
    param_types.append(torch.bfloat16)
Przemek Tredak's avatar
Przemek Tredak committed
94
95
96

batch_sizes = [1, 2]

97
all_boolean = [True, False]
Przemek Tredak's avatar
Przemek Tredak committed
98

99
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
100
all_normalizations = ["LayerNorm", "RMSNorm"]
schetlur-nv's avatar
schetlur-nv committed
101
102
103
104
105
106

def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
def _test_sanity_e2e_cuda_graph(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    """Warm up, capture and replay one training step of ``block`` as a CUDA graph.

    One step is: forward under ``fp8_autocast``, MSE loss, backward, SGD update.
    The step runs three times on a side stream, is captured into a
    ``torch.cuda.CUDAGraph`` over static input/target buffers, and is then
    replayed once after fresh data is copied into those buffers. No numerical
    result is checked; the test passes if nothing raises.

    Args:
        block: Module under test; its parameters are optimized by SGD (lr=0.1).
        bs: Batch size (second input dimension).
        dtype: dtype of the input and target tensors.
        config: ModelConfig providing ``seq_len`` and ``hidden_size``.
        fp8_recipe: FP8 recipe, or ``None`` to run without FP8.
        skip_wgrad: If True, all parameters of ``block`` are frozen first.
    """
    criterion = torch.nn.MSELoss()
    sgd = torch.optim.SGD(block.parameters(), lr=0.1)

    # Static buffers whose addresses are baked into the captured graph.
    shape = (config.seq_len, bs, config.hidden_size)
    static_input = torch.randn(*shape, device='cuda', dtype=dtype, requires_grad=True)
    static_target = torch.randn(*shape, device='cuda', dtype=dtype)

    # Data that is fed to the graph at replay time.
    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    use_fp8 = fp8_recipe is not None
    if skip_wgrad:
        _disable_wgrads(block)

    def run_step():
        # Returns the step's tensors so the caller keeps them alive, as the
        # captured graph's outputs must remain allocated.
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            output = block(static_input)
        step_loss = criterion(output, static_target)
        step_loss.backward()
        sgd.step()
        return output, step_loss

    # Warm up on a side stream before capture (PyTorch's CUDA-graph pattern).
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _ in range(3):
            sgd.zero_grad(set_to_none=True)
            out, loss = run_step()
    torch.cuda.current_stream().wait_stream(side_stream)

    graph = torch.cuda.CUDAGraph()
    sgd.zero_grad(set_to_none=True)
    with torch.cuda.graph(graph):
        static_output, static_loss = run_step()

    # Refill the captured buffers in place, then replay the recorded step.
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    graph.replay()

    torch.cuda.synchronize()


schetlur-nv's avatar
schetlur-nv committed
155
def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    """Run ``block`` under ``torch.autocast(dtype)`` and check AMP dtypes.

    A float32 input and a random boolean attention mask of shape
    ``(1, 1, seq_len, seq_len)`` are fed through ``block`` inside both
    ``torch.autocast`` and (optionally) ``fp8_autocast``, and the sum of the
    output is backpropagated.

    Asserts that the output has ``dtype`` while the input gradient and every
    trainable parameter's gradient are float32.

    Args:
        block: float32 module under test (called with ``attention_mask=``).
        bs: Batch size (second input dimension).
        dtype: Autocast dtype.
        config: ModelConfig providing ``seq_len`` and ``hidden_size``.
        fp8_recipe: FP8 recipe, or ``None`` to run without FP8.
        skip_wgrad: If True, all parameters of ``block`` are frozen first.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
    ).cuda()
    # ``.cuda()`` returns a non-leaf copy, so its gradient must be retained explicitly.
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    """Check that weight gradients land in each parameter's ``main_grad`` buffer.

    Every trainable parameter whose name contains ``"weight"``, excluding
    ``"layer_norm_weight"``, gets a zero-filled ``main_grad``. After one
    forward/backward pass (optionally under ``fp8_autocast``), each of those
    buffers must contain at least one non-zero entry. The caller builds
    ``block`` with ``fuse_wgrad_accumulation=True``.

    Args:
        block: Module under test (called with ``attention_mask=``).
        bs: Batch size (second input dimension).
        dtype: dtype of the input tensor.
        config: ModelConfig providing ``seq_len`` and ``hidden_size``.
        fp8_recipe: FP8 recipe, or ``None`` to run without FP8.
        skip_wgrad: If True, all parameters are frozen, so nothing is checked.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

    if skip_wgrad:
        _disable_wgrads(block)

    # Select the parameters once (after any freezing) so the setup and the
    # final check are guaranteed to look at the same set.
    fused_wgrad_params = [
        p
        for name, p in block.named_parameters()
        if "weight" in name and "layer_norm_weight" not in name and p.requires_grad
    ]
    for p in fused_wgrad_params:
        p.main_grad = torch.zeros_like(p)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    for p in fused_wgrad_params:
        assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."
208

Przemek Tredak's avatar
Przemek Tredak committed
209

schetlur-nv's avatar
schetlur-nv committed
210
def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of ``block`` on random hidden states.

    The input has shape ``(seq_len, bs, hidden_size)`` and ``block`` is called
    without an attention mask, optionally under ``fp8_autocast``. Only checks
    that nothing raises.

    Args:
        block: Module under test.
        bs: Batch size (second input dimension).
        dtype: dtype of the input tensor.
        config: ModelConfig providing ``seq_len`` and ``hidden_size``.
        fp8_recipe: FP8 recipe, or ``None`` to run without FP8.
        skip_wgrad: If True, all parameters of ``block`` are frozen first.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of ``block`` with a per-sample padding-style mask.

    The attention mask is a random boolean tensor of shape
    ``(bs, 1, 1, seq_len)``; the caller configures ``block`` with
    ``self_attn_mask_type="padding"``. Only checks that nothing raises.

    Args:
        block: Module under test (called with ``attention_mask=``).
        bs: Batch size (second input dimension).
        dtype: dtype of the input tensor.
        config: ModelConfig providing ``seq_len`` and ``hidden_size``.
        fp8_recipe: FP8 recipe, or ``None`` to run without FP8.
        skip_wgrad: If True, all parameters of ``block`` are frozen first.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    te_inp_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


schetlur-nv's avatar
schetlur-nv committed
244
def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of a decoder-style ``block``.

    The same random hidden states serve as both the decoder input and
    ``encoder_output``. The self-attention mask is a random boolean tensor of
    shape ``(1, 1, seq_len, seq_len)``, and the encoder-decoder mask is a
    random boolean tensor of shape ``(bs, 1, 1, seq_len)``. Only checks that
    nothing raises.

    Args:
        block: Module under test (the caller uses ``layer_type="decoder"``).
        bs: Batch size (second input dimension).
        dtype: dtype of the input tensor.
        config: ModelConfig providing ``seq_len`` and ``hidden_size``.
        fp8_recipe: FP8 recipe, or ``None`` to run without FP8.
        skip_wgrad: If True, all parameters of ``block`` are frozen first.
    """
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
    enc_dec_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            encoder_output=te_inp_hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


267
268
269
270
def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
    """Run one forward/backward pass of a single TE module on random input.

    If ``block`` returns a tuple, only its first element is used for the loss.
    The test is skipped when both weight and input gradients are disabled,
    because backward would then have nothing to differentiate.

    Args:
        block: Module under test.
        bs: Batch size (second input dimension).
        dtype: dtype of the input tensor.
        config: ModelConfig providing ``seq_len`` and ``hidden_size``.
        fp8_recipe: FP8 recipe, or ``None`` to run without FP8.
        skip_wgrad: If True, all parameters of ``block`` are frozen first.
        skip_dgrad: If True, the input does not require grad.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=not skip_dgrad
    ).cuda()

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
def _test_sanity_normalization_amp(block, bs, dtype, config, skip_wgrad, skip_dgrad):
    """Run a normalization module under ``torch.autocast(dtype)`` and check AMP dtypes.

    Asserts that the output has ``dtype`` while the input gradient (when
    computed) and every trainable parameter's gradient are float32.

    Args:
        block: float32 normalization module under test.
        bs: Batch size (second input dimension).
        dtype: Autocast dtype.
        config: ModelConfig providing ``seq_len`` and ``hidden_size``.
        skip_wgrad: If True, all parameters of ``block`` are frozen, so no
            weight gradients are computed.
        skip_dgrad: If True, the input does not require grad, so no input
            gradient is computed or checked.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    # Previously both flags were ignored (input always required grad and weights
    # were never frozen), so the parametrization did not vary what was tested.
    te_inp = torch.randn(
        config.seq_len, bs, config.hidden_size, requires_grad=not skip_dgrad
    ).cuda()
    if not skip_dgrad:
        # ``.cuda()`` returns a non-leaf copy, so its gradient must be retained explicitly.
        te_inp.retain_grad()

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    if not skip_dgrad:
        assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, bs, model, skip_wgrad, skip_dgrad, normalization):
    """Standalone LayerNorm/RMSNorm with float32 weights under ``torch.autocast(dtype)``."""
    cfg = model_configs[model]
    # Any name other than "RMSNorm" selects LayerNorm.
    norm_cls = {"RMSNorm": RMSNorm}.get(normalization, LayerNorm)

    norm_layer = norm_cls(cfg.hidden_size, eps=cfg.eps)
    block = norm_layer.to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, bs, dtype, cfg, skip_wgrad, skip_dgrad)


Przemek Tredak's avatar
Przemek Tredak committed
332
333
334
335
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad,
                                 zero_centered_gamma, skip_dgrad,
                                 normalization):
    """Forward/backward sanity check for LayerNormLinear (hidden -> 3 * hidden)."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = (
        LayerNormLinear(
            config.hidden_size,
            config.hidden_size * 3,
            eps=config.eps,
            init_method=init_method,
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
Przemek Tredak's avatar
Przemek Tredak committed
367
368
369
370
371
372


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad, skip_dgrad):
    """Forward/backward sanity check for a square Linear (hidden -> hidden)."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        Linear(
            config.hidden_size, config.hidden_size, init_method=output_layer_init_method
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
Przemek Tredak's avatar
Przemek Tredak committed
392
393
394
395
396
397


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad,
                              zero_centered_gamma, skip_dgrad, activation,
                              normalization):
    """Forward/backward sanity check for LayerNormMLP with a 4x hidden FFN size."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        LayerNormMLP(
            config.hidden_size,
            4 * config.hidden_size,
            eps=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            zero_centered_gamma=zero_centered_gamma,
            activation=activation,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )
    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
Przemek Tredak's avatar
Przemek Tredak committed
433
434
435
436
437
438


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
                    zero_centered_gamma, bias, activation,
                    normalization):
    """GPT-style TransformerLayer sanity check.

    Uses ``apply_residual_connection_post_layernorm=False`` and
    ``output_layernorm=False``.
    """
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            bias=bias,
            activation=activation,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
482
483
484
485
486
487


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                     normalization):
    """BERT-style TransformerLayer sanity check.

    Uses post-layernorm residuals, an output layernorm, and a padding
    self-attention mask.
    """
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=True,
            output_layernorm=True,
            zero_centered_gamma=zero_centered_gamma,
            self_attn_mask_type="padding",
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
527
528
529
530
531
532


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                   normalization):
    """T5-style decoder TransformerLayer sanity check (self- and cross-attention)."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="decoder",
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
572
573
574
575
576
577


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
    """TransformerLayer with float32 weights run under ``torch.autocast(dtype)``."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
        )
        # Weights stay float32; autocast inside the helper supplies ``dtype``.
        .to(dtype=torch.float32)
        .cuda()
    )

    _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
606
607
608
609
610
611


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
    """TransformerLayer sanity check with ``drop_path_rate=1.0``."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            drop_path_rate=1.0,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
643
644
645
646
647
648


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
    """TransformerLayer sanity check with ``fuse_qkv_params=True``."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            fuse_qkv_params=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
680
681
682
683
684
685
686
687
688
689


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
    """TransformerLayer with ``fuse_wgrad_accumulation=True`` must fill ``main_grad``."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            fuse_qkv_params=True,
            fuse_wgrad_accumulation=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_recipe, skip_wgrad)
720
721
722
723
724
725
726
727


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                        normalization):
    """GPT-style TransformerLayer training step captured and replayed as a CUDA graph."""
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    if normalization == "RMSNorm" and zero_centered_gamma:
        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            zero_centered_gamma=zero_centered_gamma,
            fuse_qkv_params=True,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
    )

    _test_sanity_e2e_cuda_graph(block, bs, dtype, config, fp8_recipe, skip_wgrad)
765
766
767
768
769
770
771
772
773
774
775
776
777

def test_model_multiple_cast():
    """A Linear's output dtype should follow the module after casting it with ``.half()``."""
    inp = torch.zeros((16, 16)).cuda()
    layer = Linear(16, 32)

    out_fp32 = layer(inp)
    assert out_fp32.dtype == torch.float32

    layer.half()
    out_fp16 = layer(inp.half())
    assert out_fp16.dtype == torch.float16