# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from dataclasses import dataclass
from typing import Optional
from contextlib import nullcontext

import torch
import pytest

from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    fp8_model_init,
)
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
    is_bf16_compatible,
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    LayerNormMLP,
    TransformerLayer,
    RMSNorm,
    LayerNorm,
    get_cpu_offload_context,
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import (
    gemm,
    fp8_gemm,
    gelu,
    cast_to_fp8,
    cast_from_fp8,
)
from transformer_engine.pytorch.module.base import get_workspace
from test_onnx_export import create_meta

# Only run FP8 tests on FP8-capable hardware (e.g. H100).
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


def custom_amax_to_scale(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: torch.Tensor,
    recipe: recipe.DelayedScaling,
) -> torch.Tensor:
    """Custom func to test recipe."""
    sf = fp8_max / amax
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)

    return sf


def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    """Custom func to test recipe."""
    return torch.min(amax_history, dim=0).values


@dataclass
class ModelConfig:
    """Transformer model configuration"""

    num_layers: int
    seq_len: int
    batch_size: int
    hidden_size: int
    num_attention_heads: int
    kv_channels: Optional[int] = None

    def is_fp8_supported(self):
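        # FP8 GEMMs require the flattened token dimension and the hidden
        # dimension to be multiples of 16.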
        if self.seq_len * self.batch_size % 16:
            return False
        if self.hidden_size % 16:
            return False
        return True


model_configs = {
    "126m": ModelConfig(12, 2048, 2, 768, 12),
    "small": ModelConfig(2, 32, 2, 64, 2),
    "weird": ModelConfig(2, 37, 3, 69, 3),
}

fp8_recipes = [
    None,  # Handles non-FP8 case
    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        override_linear_precision=(False, False, True),
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo="most_recent",
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo="max",
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo=custom_amax_compute,
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        scaling_factor_compute_algo=custom_amax_to_scale,
    ),
]

param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]


def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


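# Reset the FP8 global state after every test so runs stay independent.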
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
    FP8GlobalStateManager.reset()


def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for capture.
    static_input = torch.randn(
        config.seq_len,
        config.batch_size,
        config.hidden_size,
        device="cuda",
        dtype=dtype,
        requires_grad=True,
    )
    static_target = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
    )

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    use_fp8 = fp8_recipe is not None
    if skip_wgrad:
        _disable_wgrads(block)

    # Pre graph capture warmup in a separate stream.
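    # (PyTorch's CUDA graphs API expects warmup iterations on a side stream
    # before capture, so allocator and autograd state are initialized outside
    # the graph.)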
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            optimizer.zero_grad(set_to_none=True)
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
                out = block(static_input)
            loss = loss_fn(out, static_target)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture.
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
            static_output = block(static_input)
        static_loss = loss_fn(static_output, static_target)
        static_loss.backward()
        optimizer.step()

    # Fills the graph's input memory with new data to compute on
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    g.replay()

    torch.cuda.synchronize()


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

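    # Create Megatron-style `main_grad` buffers: with fuse_wgrad_accumulation,
    # weight gradients are accumulated directly into `main_grad` rather than
    # `grad` (verified by the assertion after the backward pass below).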
    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            p.main_grad = torch.zeros_like(p)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."


def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    if skip_wgrad:
        _disable_wgrads(block)

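    # get_cpu_offload_context returns a context manager that offloads saved
    # activations to CPU, plus a sync function to apply to the forward output.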
    if cpu_offload:
        offload_context, sync_function = get_cpu_offload_context(enabled=True)
    else:
        offload_context = nullcontext()
        sync_function = lambda x: x

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
        te_out = block(te_inp_hidden_states)
    te_out = sync_function(te_out)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    te_inp_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    enc_dec_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            encoder_output=te_inp_hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=True,
    )
    te_inp.retain_grad()

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    config = model_configs[model]
    module = RMSNorm if normalization == "RMSNorm" else LayerNorm

    block = module(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = LayerNormLinear(
        config.hidden_size,
        config.hidden_size * 3,
        init_method=init_method,
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=output_layer_init_method,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.seq_len

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
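    # With fp8_model_init, the module's parameters are created directly in FP8
    # rather than as higher-precision weights that are cast at runtime.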
    with fp8_model_init(enabled=use_fp8 and fp8_model_params):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
    cpu_offload,
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)


def test_sanity_gpt_126m():
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=True,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
        cpu_offload=False,
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        zero_centered_gamma=zero_centered_gamma,
        self_attn_mask_type="padding",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        fuse_qkv_params=True,
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)


def test_model_multiple_cast():
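    # A module cast after construction (here via .half()) must recast its
    # parameters so subsequent forwards run in the new dtype.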
    a = torch.zeros((16, 16), device="cuda")
    m = Linear(16, 32)

    y = m(a)
    assert y.dtype == torch.float32

    m.half()
    a = a.half()

    y2 = m(a)
    assert y2.dtype == torch.float16


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
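    # Slice into a scratch buffer at an element offset so the GEMM operands'
    # data pointers are deliberately misaligned.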
    scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
    inp = torch.reshape(scratchpad[offset:-offset], (N, N))
    weight = torch.reshape(scratchpad[offset * 2 :], (N, N))

    _, _, _ = gemm(A=weight, B=inp, dtype=datatype, workspace=get_workspace())
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    offset = 16
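    # Slicing the FP8 scratch buffer at this element offset shifts the operand
    # pointers relative to the allocation, exercising the FP8 GEMM's handling
    # of pointer alignment.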
    scratchpad = torch.randn(N * N + offset, device="cuda", dtype=datatype)

    fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
    fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT

    nb_inp_scales, nb_weight_scales = 1, N
    scale_factor = 1.0
    meta_inp = create_meta(scale_factor, nb_inp_scales)
    meta_weight = create_meta(scale_factor, nb_weight_scales)
    inp_type = tex.DType.kFloat8E4M3
    weights_type = tex.DType.kFloat8E4M3
    outp_type = datatype

    scratchpad_fp8 = cast_to_fp8(scratchpad, meta_weight, fp8_tensor_inp, inp_type)
    inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N))
    weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N))
    _, _ = fp8_gemm(
        weight_fp8,
        meta_weight.scale_inv,
        fp8_tensor_weight,
        weights_type,
        inp_fp8,
        meta_inp.scale_inv,
        fp8_tensor_inp,
        inp_type,
        outp_type,
        get_workspace(),
        bias=None,
        use_bias=False,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()