# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from dataclasses import dataclass
from typing import Optional
from contextlib import nullcontext

import torch
import pytest
import os
from torch.utils.cpp_extension import IS_HIP_EXTENSION

import transformer_engine.pytorch
from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    fp8_model_init,
)
from transformer_engine.pytorch.utils import (
    get_device_compute_capability,
    init_method_normal,
    scaled_init_method_normal,
    is_bf16_compatible,
    get_cudnn_version,
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    GroupedLinear,
    LayerNormMLP,
    TransformerLayer,
    RMSNorm,
    LayerNorm,
    get_cpu_offload_context,
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
    Float8Tensor,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from utils import dtype_tols

# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Record initial RNG state from script run.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))


if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests should behave the same when debug mode is enabled.
    # A dummy feature is passed in to keep debug mode from being switched
    # off, which can happen when no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )


def create_meta(scale_factor: float, size: int = 1):
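    """Build an FP8TensorMeta with scale, scale_inv, and amax_history buffers on the GPU."""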
    meta = tex.FP8TensorMeta()
    meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
    meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
    meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
    return meta

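# ROCm builds default to the hipBLASLt GEMM backend; rocBLAS is only used when
# NVTE_USE_ROCBLAS is set and NVTE_USE_HIPBLASLT is not.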
if IS_HIP_EXTENSION:
    from functools import cache

    @cache
    def use_hipblaslt() -> bool:
        return (
            os.getenv("NVTE_USE_HIPBLASLT") is not None
            or os.getenv("NVTE_USE_ROCBLAS") is None
        )

def custom_amax_to_scale(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: torch.Tensor,
    recipe: recipe.DelayedScaling,
) -> torch.Tensor:
    """Custom scaling-factor compute function used to exercise the DelayedScaling recipe."""
    sf = fp8_max / amax
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)

    return sf


def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    """Custom amax compute function (minimum over the history) used to exercise the DelayedScaling recipe."""
    return torch.min(amax_history, dim=0).values


def reset_rng_states() -> None:
    """Revert to the initial RNG state recorded at import time."""
    global _cpu_rng_state, _cuda_rng_state
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


@dataclass
class ModelConfig:
    """Transformer model configuration"""

    num_layers: int
    seq_len: int
    batch_size: int
    hidden_size: int
    num_attention_heads: int
    kv_channels: Optional[int] = None

    def is_fp8_supported(self):
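        """Whether seq_len * batch_size and hidden_size are multiples of 16 (needed for FP8 execution)."""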
        if self.seq_len * self.batch_size % 16:
            return False
        if self.hidden_size % 16:
            return False
        return True


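# ModelConfig positional fields: num_layers, seq_len, batch_size, hidden_size, num_attention_heads[, kv_channels].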
model_configs = {
    "126m": ModelConfig(12, 2048, 2, 768, 12),
    "small": ModelConfig(2, 32, 2, 64, 2),
    "weird": ModelConfig(2, 37, 3, 69, 3),
    "large": ModelConfig(1, 128, 2, 512, 4, 128),
}

fp8_recipes = [
    None,  # Test non-FP8 execution
    recipe.MXFP8BlockScaling(),  # Test default MXFP8 block scaling
    recipe.Float8CurrentScaling(),  # Test default FP8 current scaling
    recipe.Float8BlockScaling(),  # Test default FP8 block scaling
    recipe.DelayedScaling(),  # Test default delayed scaling
    recipe.DelayedScaling(  # Test most_recent amax compute algo
        amax_history_len=16,
        amax_compute_algo="most_recent",
    ),
    recipe.DelayedScaling(  # Test custom amax and scale compute algos
        fp8_format=recipe.Format.E4M3,
        amax_compute_algo=custom_amax_compute,
        scaling_factor_compute_algo=custom_amax_to_scale,
    ),
]

param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]


def _disable_wgrads(block):
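    """Freeze all parameters so that weight gradients are not computed."""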
    for p in block.parameters():
        p.requires_grad = False


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
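    """Reset FP8 global state after each test."""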
    yield
    FP8GlobalStateManager.reset()


def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
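    """Capture a fwd/bwd/optimizer step in a CUDA graph and replay it on new data."""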
    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for capture.
    static_input = torch.randn(
        config.seq_len,
        config.batch_size,
        config.hidden_size,
        device="cuda",
        dtype=dtype,
        requires_grad=True,
    )
    static_target = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
    )

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    use_fp8 = fp8_recipe is not None
    if skip_wgrad:
        _disable_wgrads(block)

    # Pre graph capture warmup in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            optimizer.zero_grad(set_to_none=True)
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
                out = block(static_input)
            loss = loss_fn(out, static_target)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture.
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
            static_output = block(static_input)
        static_loss = loss_fn(static_output, static_target)
        static_loss.backward()
        optimizer.step()

    # Fills the graph's input memory with new data to compute on
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    g.replay()

    torch.cuda.synchronize()


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
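    """Run fwd/bwd under torch.autocast and check the output and gradient dtypes."""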
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp_hidden_states.grad is not None, "Gradient should not be empty"
    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
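    """Check that weight gradients are accumulated into param.main_grad when fuse_wgrad_accumulation is used."""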
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            p.main_grad = torch.zeros_like(p)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    failed_grads = []
    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            if not torch.count_nonzero(p.main_grad) > 0:
                failed_grads.append(name)
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."


def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
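    """Minimal fwd/bwd pass through the block, optionally with CPU offloading of activations."""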
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    if cpu_offload:
        offload_context, sync_function = get_cpu_offload_context(enabled=True)
    else:
        offload_context = nullcontext()
        sync_function = lambda x: x

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
        te_out = block(te_inp_hidden_states)
    te_out = sync_function(te_out)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    te_inp_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    enc_dec_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            encoder_output=te_inp_hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_common(
    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
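    """Shared fwd/bwd sanity check; optionally exercises the is_first_microbatch path."""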
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        if not microbatching:
            te_out = block(te_inp)
        else:
            _ = block(te_inp, is_first_microbatch=True)
            te_out = block(te_inp, is_first_microbatch=False)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=True,
    )
    te_inp.retain_grad()

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp.grad is not None, "Gradient should not be empty"
    assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    config = model_configs[model]
    module = RMSNorm if normalization == "RMSNorm" else LayerNorm

    block = module(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
    microbatching,
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = LayerNormLinear(
        config.hidden_size,
        config.hidden_size * 3,
        init_method=init_method,
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=output_layer_init_method,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.seq_len

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    num_tokens = bs * config.seq_len * (num_gemms - 1)

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    m_splits = [bs * config.seq_len] * num_gemms
    if empty_split == "first":
        m_splits[0] = 0
    elif empty_split == "last":
        m_splits[-1] = 0
    elif empty_split == "middle":
        m_splits[num_gemms // 2] = 0

    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_grouped_linear(inp_hidden_states, m_splits)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
    microbatching,
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
    cpu_offload,
):
    if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("CPU offload is not supported in debug mode.")
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)


def test_sanity_gpt_126m():
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=True,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
        cpu_offload=False,
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        zero_centered_gamma=zero_centered_gamma,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    if IS_HIP_EXTENSION and not use_hipblaslt():
        pytest.skip("CUDA graph capture not supported with rocBLAS path")
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling():
            pytest.skip("cuda graph not supported for float8_block_scaling recipe")
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        fuse_qkv_params=True,
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)


def test_model_multiple_cast():
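    """Check that casting the module to half precision also changes its compute/output dtype."""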
    a = torch.zeros((16, 16), device="cuda")
    m = Linear(16, 32)

    y = m(a)
    assert y.dtype == torch.float32

    m.half()
    a = a.half()

    y2 = m(a)
    assert y2.dtype == torch.float16


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
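    """general_gemm should handle input and weight views that start at unaligned storage offsets."""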
    scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
    inp = torch.reshape(scratchpad[offset:-offset], (N, N))
    weight = torch.reshape(scratchpad[offset * 2 :], (N, N))

    _ = general_gemm(A=weight, B=inp, workspace=get_workspace())
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    offset = 16
    scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)

    scales = torch.ones(1).cuda().squeeze()
    amaxes = torch.ones(1).cuda().squeeze()
    dtype = tex.DType.kFloat8E4M3
    fp8_quantizer = Float8Quantizer(scales, amaxes, dtype)

    outp_type = datatype

    scratchpad_fp8 = fp8_quantizer(scratchpad)
    inp_fp8 = torch.reshape(scratchpad_fp8[0][:-offset], (N, N))
    weight_fp8 = torch.reshape(scratchpad_fp8[0][offset:], (N, N))
    general_gemm(
        weight_fp8,
        inp_fp8,
        get_workspace(),
        outp_type,
        bias=None,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
    config = model_configs[model]
    outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
    outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
    outputs_checkpoint_v1_6 = _run_attention_extra_state(
        dtype, config, mimic_v1_6=True, checkpoint=True
    )

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(dict(rtol=2e-2, atol=2e-3))
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )


def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
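    """Train a TransformerLayer with FP8 attention, optionally round-tripping weights through a checkpoint."""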
    steps = 10
    path = "checkpoint.pt"
    fp8_enabled = True
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_enabled,
        fp8_mha=False,
    )

    reset_rng_states()
    hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    def get_model(dtype, config):
        sigma = 0.023
        init_method = init_method_normal(sigma)
        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

        with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
            block = TransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_attention_heads,
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
                device="cuda",
            )
        return block

    block = get_model(dtype, config)
    for i in range(steps // 2):
        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    if checkpoint:
        sd = block.state_dict()
        if mimic_v1_6:
            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
                "self_attention.core_attention._extra_state"
            ]
            del sd["self_attention.core_attention._extra_state"]
        torch.save(sd, path)

        param_grads = []
        for p in block.parameters():
            if p.requires_grad:
                param_grads.append(p.grad.clone())

        _cpu_rng_state_new = torch.get_rng_state()
        _cuda_rng_state_new = torch.cuda.get_rng_state()

        del block
        block = get_model(dtype, config)
        block.load_state_dict(torch.load(path, weights_only=False))
        torch.set_rng_state(_cpu_rng_state_new)
        torch.cuda.set_rng_state(_cuda_rng_state_new)

        for p in block.parameters():
            if p.requires_grad:
                p.grad = param_grads.pop(0)

        assert not param_grads, "Not all parameter gradients were restored."

    for i in range((steps + 1) // 2):
        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    torch.cuda.synchronize()

    if os.path.exists(path):
        os.remove(path)

    outputs = [output, hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)

    return outputs


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """Test the functionality of replace_raw_data"""
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    fp8_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    fp8_tensor = fp8_quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
    random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
    fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)

    attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
    attrs = {}
    for attr in attrs_to_check:
        attrs[attr] = getattr(fp8_tensor, attr)

    old_data = fp8_tensor._data
    new_data = torch.empty_like(old_data)
    replace_raw_data(fp8_tensor, new_data)

    # Make sure the new_data is properly assigned.
    assert fp8_tensor._data.data_ptr() != old_data.data_ptr()
    assert fp8_tensor._data.data_ptr() == new_data.data_ptr()
    # Make sure the values are not changed.
    torch.testing.assert_close(old_data, fp8_tensor._data, atol=0, rtol=0)
    # Make sure other attributes are not changed (totally identical)
    for attr in attrs_to_check:
        assert id(getattr(fp8_tensor, attr)) == id(attrs[attr])


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_init_high_precision_init_val():
    """Test fp8_model_init with preserve_high_precision_init_val=True"""
    with fp8_model_init(preserve_high_precision_init_val=True):
        model = Linear(768, 768)

    weight = model.weight

    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
    assert hasattr(
        weight, "clear_high_precision_init_val"
    ), "clear_high_precision_init_val() not found"

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape, dtype=weight.dtype, device=weight.device
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)

    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )

    weight.clear_high_precision_init_val()
    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() did not work"
    assert (
        getattr(weight, "_high_precision_init_val", None) is None
    ), "clear_high_precision_init_val() did not work"


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules."""

    # A torch.autograd.Function used as the callable to checkpoint.
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    grad_checkpoint = inp.grad.clone()

    # Reset the accumulated gradient so the non-checkpointed pass is compared independently.
    inp.grad = None
    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)


@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test that quantized weights hold only the data needed for inference."""
    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    # Tensor dimensions
    sequence_length = 32
    hidden_size = 32

    # Skip invalid configurations
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Construct quantization recipe
    with_quantization = quantization not in (None, "None")
    quantization_recipe = None
    if quantization == "fp8_delayed_scaling":
        quantization_recipe = recipe.DelayedScaling()
    elif quantization == "fp8_current_scaling":
        quantization_recipe = recipe.Float8CurrentScaling()
    elif quantization == "mxfp8":
        quantization_recipe = recipe.MXFP8BlockScaling()

    # Construct module
    module = None
    with torch.no_grad():
        with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
            if module_name == "Linear":
                module = Linear(hidden_size, hidden_size)
            elif module_name == "LayerNormLinear":
                module = LayerNormLinear(hidden_size, hidden_size)
            elif module_name == "LayerNormMLP":
                module = LayerNormMLP(hidden_size, hidden_size)
            elif module_name == "GroupedLinear":
                module = GroupedLinear(1, hidden_size, hidden_size)
            elif module_name == "ops.Linear":
                module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)

    def check_weights():
        """Helper function to check that weight parameters have expected data"""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Check that modules have expected weights after initialization
    check_weights()

    # Check that modules have expected weights after forward pass
    with torch.inference_mode():
        x = torch.zeros(sequence_length, hidden_size, device="cuda")
        kwargs = {}
        if module_name == "GroupedLinear":
            kwargs["m_splits"] = [sequence_length]
        with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
            y = module(x, **kwargs)
    check_weights()