test_sanity.py 40.5 KB
Newer Older
1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
from typing import Optional

Przemek Tredak's avatar
Przemek Tredak committed
7
8
import torch
import pytest
9
import os
yuguo's avatar
yuguo committed
10
from torch.utils.cpp_extension import IS_HIP_EXTENSION
Przemek Tredak's avatar
Przemek Tredak committed
11

12
13
14
import transformer_engine
import transformer_engine.pytorch as te
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
Przemek Tredak's avatar
Przemek Tredak committed
15
16
17
18
19
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch import (
20
21
    autocast,
    quantized_model_init,
Przemek Tredak's avatar
Przemek Tredak committed
22
23
    LayerNormLinear,
    Linear,
24
    GroupedLinear,
Przemek Tredak's avatar
Przemek Tredak committed
25
26
    LayerNormMLP,
    TransformerLayer,
27
28
    RMSNorm,
    LayerNorm,
29
30
31
32
33
34
35
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
    Float8Tensor,
    MXFP8Tensor,
    checkpoint,
    QuantizedTensor,
    is_bf16_available,
Przemek Tredak's avatar
Przemek Tredak committed
36
37
)
from transformer_engine.common import recipe
38
import transformer_engine_torch as tex
39
from transformer_engine.pytorch.cpp_extensions import general_gemm
40
from transformer_engine.pytorch.tensor.utils import replace_raw_data
41
from utils import ModelConfig
Przemek Tredak's avatar
Przemek Tredak committed
42

43
# Only run FP8 tests on supported devices.
# Every availability query returns (flag, reason); all reasons are kept because
# the tests below pass them to pytest.skip().
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    te.is_fp8_block_scaling_available(return_reason=True)
)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
47

48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Seed the CPU and CUDA generators with a fixed value at import time.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

NVTE_TEST_NVINSPECT_ENABLED = int(os.getenv("NVTE_TEST_NVINSPECT_ENABLED", "0"))


if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests are expected to behave the same with debug enabled.
    # A dummy feature is supplied so that debug mode is not switched off,
    # which can happen when no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )

68

69
70
71
72
73
74
75
76
77
def is_fp8_supported(config: ModelConfig):
    """Return whether ``config`` has FP8-compatible dimensions.

    The config qualifies only when all of the following are multiples of 16:
    ``max_seqlen_q * batch_size``, ``max_seqlen_kv * batch_size``,
    ``hidden_size`` and ``hidden_size_kv``.
    """
    batch = config.batch_size
    # Token counts for the query side, then the key/value side.
    if (config.max_seqlen_q * batch) % 16 != 0:
        return False
    if (config.max_seqlen_kv * batch) % 16 != 0:
        return False
    # Both hidden dimensions must be aligned as well.
    return config.hidden_size % 16 == 0 and config.hidden_size_kv % 16 == 0
Przemek Tredak's avatar
Przemek Tredak committed
78

79

Przemek Tredak's avatar
Przemek Tredak committed
80
# Model shapes exercised by the sanity tests, keyed by a short nickname.
# NOTE(review): the meaning of the four positional ModelConfig arguments is
# defined in utils.ModelConfig, which is not visible here — confirm there.
model_configs = {
    nickname: ModelConfig(*positional_args, num_layers=layer_count)
    for nickname, positional_args, layer_count in (
        ("126m", (2, 2048, 12, 64), 12),
        ("small", (2, 32, 2, 32), 2),
        ("weird", (3, 37, 3, 23), 2),
        ("large", (2, 128, 4, 128), 1),
    )
}

87
88
89
90
91
92
93
94
95

def nvfp4_vanilla():
    """Build an ``NVFP4BlockScaling`` recipe with default ``QParams`` in all three slots.

    The forward-input, forward-weight and backward-gradient quantization
    parameters are each replaced by a freshly constructed ``recipe.QParams()``.
    """
    vanilla_recipe = recipe.NVFP4BlockScaling()
    for slot in ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad"):
        # Each slot gets its own QParams instance.
        setattr(vanilla_recipe, slot, recipe.QParams())
    return vanilla_recipe


96
97
98
# Quantization recipes to sweep. The list order determines the pytest
# parametrization order; ``None`` (no quantization) always comes last.
fp8_recipes = []
if mxfp8_available:
    fp8_recipes += [
        recipe.MXFP8BlockScaling(),
        nvfp4_vanilla(),  # TODO: fix check for this
    ]
if fp8_block_scaling_available:
    fp8_recipes += [recipe.Float8BlockScaling()]
if fp8_available:
    fp8_recipes += [recipe.Float8CurrentScaling(), recipe.DelayedScaling()]
fp8_recipes += [None]
Przemek Tredak's avatar
Przemek Tredak committed
106

107
# Parameter dtypes to sweep; bf16 requires sm_80 or higher, so it is only
# included when the device reports support for it.
param_types = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_available() else []
)

# Generic sweeps shared by several parametrized tests.
all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]
Przemek Tredak's avatar
Przemek Tredak committed
113

114
115
116
117
118
119
120
121
122
123
124
# Activation names swept by the MLP tests; each plain activation is followed
# by its gated variant.
all_activations = [
    "gelu", "geglu",
    "qgelu", "qgeglu",
    "relu", "reglu",
    "srelu", "sreglu",
    "silu", "swiglu",
    "clamped_swiglu",
]
# Normalization layer names swept by the tests.
all_normalizations = ["LayerNorm", "RMSNorm"]
schetlur-nv's avatar
schetlur-nv committed
128

129

schetlur-nv's avatar
schetlur-nv committed
130
131
132
133
134
def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


135
136
137
138
139
140
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that calls ``FP8GlobalStateManager.reset()`` after each test."""
    yield  # The test itself runs at this point.
    FP8GlobalStateManager.reset()


141
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of ``block`` under ``torch.autocast`` and check dtypes.

    The FP32 input and a random boolean attention mask are created on CUDA.
    The forward pass runs inside ``torch.autocast(dtype=dtype)`` and inside the
    quantization ``autocast``, which is enabled only when ``fp8_recipe`` is not
    None. Afterwards the output must have ``dtype``, while the input gradient
    and every computed weight gradient must be FP32.
    """
    hidden_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    mask_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)

    hidden_states = torch.randn(
        hidden_shape, dtype=torch.float32, device="cuda", requires_grad=True
    )
    hidden_states.retain_grad()
    attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    quantize = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with autocast(enabled=quantize, recipe=fp8_recipe):
            te_out = block(hidden_states, attention_mask=attn_mask)
        # The reduction stays inside torch.autocast but outside the quantization context.
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert hidden_states.grad is not None, "Gradient should not be empty"
    assert hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, param in block.named_parameters():
        if not param.requires_grad:
            continue
        assert param.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


176
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    """Check that a backward pass writes into the ``main_grad`` buffers of ``block``.

    Every trainable parameter whose name contains "weight" (but not
    "layer_norm_weight") gets a zero-filled ``main_grad`` buffer before the
    forward pass. After backward, each of those buffers must contain at least
    one non-zero element; otherwise the offending parameter names are reported.
    """

    def accumulated_weights():
        """Yield (name, param) for the trainable weights that carry ``main_grad``."""
        for param_name, param in block.named_parameters():
            if "layer_norm_weight" in param_name:
                continue
            if "weight" in param_name and param.requires_grad:
                yield param_name, param

    hidden_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    mask_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)

    hidden_states = torch.randn(hidden_shape, dtype=dtype, device="cuda", requires_grad=True)
    attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    for _, weight in accumulated_weights():
        weight.main_grad = torch.zeros_like(weight)

    quantize = fp8_recipe is not None
    with autocast(enabled=quantize, recipe=fp8_recipe):
        te_out = block(hidden_states, attention_mask=attn_mask)
    te_out.sum().backward()
    torch.cuda.synchronize()

    failed_grads = [
        weight_name
        for weight_name, weight in accumulated_weights()
        if not torch.count_nonzero(weight.main_grad) > 0
    ]
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
214

Przemek Tredak's avatar
Przemek Tredak committed
215

216
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of ``block`` on a random CUDA input.

    No attention mask is passed. The quantization ``autocast`` is enabled only
    when ``fp8_recipe`` is not None. The test passes if nothing raises.
    """
    hidden_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    hidden_states = torch.randn(hidden_shape, dtype=dtype, device="cuda", requires_grad=True)

    if skip_wgrad:
        _disable_wgrads(block)

    quantize = fp8_recipe is not None
    with autocast(enabled=quantize, recipe=fp8_recipe):
        te_out = block(hidden_states)
    te_out.sum().backward()
    torch.cuda.synchronize()


235
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of ``block`` with a per-batch boolean attention mask.

    The mask has shape ``(batch_size, 1, 1, max_seqlen_q)``. The quantization
    ``autocast`` is enabled only when ``fp8_recipe`` is not None. The test
    passes if nothing raises.
    """
    hidden_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    mask_shape = (config.batch_size, 1, 1, config.max_seqlen_q)

    hidden_states = torch.randn(hidden_shape, dtype=dtype, device="cuda", requires_grad=True)
    attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    quantize = fp8_recipe is not None
    with autocast(enabled=quantize, recipe=fp8_recipe):
        te_out = block(hidden_states, attention_mask=attn_mask)
    te_out.sum().backward()
    torch.cuda.synchronize()


261
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of a decoder-style ``block``.

    The same random hidden states are passed both as the block input and as
    ``encoder_output``. Two random boolean masks are created: a self-attention
    mask of shape ``(1, 1, max_seqlen_q, max_seqlen_kv)`` and an
    encoder-decoder mask of shape ``(batch_size, 1, 1, max_seqlen_kv)``.
    The test passes if nothing raises.
    """
    hidden_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    self_mask_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
    cross_mask_shape = (config.batch_size, 1, 1, config.max_seqlen_kv)

    hidden_states = torch.randn(hidden_shape, dtype=dtype, device="cuda", requires_grad=True)
    self_attn_mask = torch.randint(2, self_mask_shape, dtype=torch.bool, device="cuda")
    enc_dec_attn_mask = torch.randint(2, cross_mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    quantize = fp8_recipe is not None
    with autocast(enabled=quantize, recipe=fp8_recipe):
        te_out = block(
            hidden_states,
            attention_mask=self_attn_mask,
            encoder_output=hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    te_out.sum().backward()
    torch.cuda.synchronize()


298
299
300
def _test_sanity_common(
    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
    """Run a forward/backward pass of a single module on a random CUDA input.

    The test is skipped when both ``skip_dgrad`` and ``skip_wgrad`` are set,
    because nothing would require a gradient. With ``microbatching`` the block
    is called twice, first with ``is_first_microbatch=True`` and then with
    ``is_first_microbatch=False``; only the second output is used for the
    loss. If the block returns a tuple, its first element is used.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    input_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    te_inp = torch.randn(input_shape, dtype=dtype, device="cuda", requires_grad=not skip_dgrad)

    if skip_wgrad:
        _disable_wgrads(block)

    quantize = fp8_recipe is not None
    with autocast(enabled=quantize, recipe=fp8_recipe):
        if microbatching:
            _ = block(te_inp, is_first_microbatch=True)
            te_out = block(te_inp, is_first_microbatch=False)
        else:
            te_out = block(te_inp)

    primary_out = te_out[0] if isinstance(te_out, tuple) else te_out
    primary_out.sum().backward()
    torch.cuda.synchronize()


328
def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    """Run a normalization ``block`` under ``torch.autocast`` and check dtypes.

    The test is skipped when both ``skip_dgrad`` and ``skip_wgrad`` are set.
    The input is created with the default dtype on CUDA and always requires a
    gradient. After backward, the output must have ``dtype`` while the input
    gradient and every computed parameter gradient must be FP32.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    input_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    te_inp = torch.randn(input_shape, device="cuda", requires_grad=True)
    te_inp.retain_grad()

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp.grad is not None, "Gradient should not be empty"
    assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, param in block.named_parameters():
        if not param.requires_grad:
            continue
        assert param.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    """Sanity-check ``RMSNorm`` / ``LayerNorm`` with FP32 parameters under AMP."""
    config = model_configs[model]
    if normalization == "RMSNorm":
        norm_cls = RMSNorm
    else:
        norm_cls = LayerNorm

    block = norm_cls(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
365
366


Przemek Tredak's avatar
Przemek Tredak committed
367
368
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
    microbatching,
):
    """Sanity-check ``LayerNormLinear`` (hidden_size -> 3 * hidden_size).

    When a quantization recipe is given, the test is skipped if the device,
    the model shape, or the dtype does not support it.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            # Query the reason directly so this skip does not depend on a
            # module-level reason string being defined.
            pytest.skip(te.is_fp8_block_scaling_available(return_reason=True)[1])
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        # Use the module-level helper, as the other tests in this file do.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = LayerNormLinear(
        config.hidden_size,
        config.hidden_size * 3,
        init_method=init_method,
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
412
413
414
415


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    """Sanity-check a square ``Linear`` layer (hidden_size -> hidden_size)."""
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    scaled_init = scaled_init_method_normal(sigma, config.num_layers)

    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=scaled_init,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
440
441


442
443
444
445
446
447
448
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    """Sanity-check ``Linear`` on a 2D input, including the zero-token case (``bs == 0``).

    The layer maps hidden_size -> 4 * hidden_size and is optionally created
    with quantized parameters. The output shape is checked after backward.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.max_seqlen_q

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            # Query the reason directly so this skip does not depend on a
            # module-level reason string being defined.
            pytest.skip(te.is_fp8_block_scaling_available(return_reason=True)[1])
        # Use the module-level helper, as the other tests in this file do.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    use_fp8 = fp8_recipe is not None
    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


483
484
485
486
487
488
489
490
491
492
493
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
    """Sanity-check ``GroupedLinear`` when one of the GEMM splits is empty.

    All splits have the same size except the one chosen by ``empty_split``,
    which is zero; the input therefore holds ``num_gemms - 1`` splits' worth of
    tokens. The output shape is checked after backward.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    tokens_per_gemm = bs * config.max_seqlen_q
    num_tokens = tokens_per_gemm * (num_gemms - 1)

    use_fp8 = fp8_recipe is not None
    if use_fp8 and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if use_fp8 and fp8_recipe.nvfp4():
        pytest.skip("NVFP4 not supported for grouped linear")

    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    # Pick which split receives no tokens; unknown values leave all splits full.
    empty_index = {"first": 0, "last": -1, "middle": num_gemms // 2}.get(empty_split)
    m_splits = [tokens_per_gemm] * num_gemms
    if empty_index is not None:
        m_splits[empty_index] = 0

    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        out = te_grouped_linear(inp_hidden_states, m_splits)
    out.sum().backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


Przemek Tredak's avatar
Przemek Tredak committed
532
533
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
@pytest.mark.parametrize("checkpoint", all_boolean)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
    microbatching,
    checkpoint,
):
    """Sanity-check ``LayerNormMLP`` (hidden_size -> 4 * hidden_size) across activations.

    When a quantization recipe is given, the test is skipped if the device,
    the model shape, or the dtype does not support it. Extra activation
    parameters are passed only for "clamped_swiglu".
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            # Query the reason directly so this skip does not depend on a
            # module-level reason string being defined.
            pytest.skip(te.is_fp8_block_scaling_available(return_reason=True)[1])
        # Use the module-level helper, as the other tests in this file do.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
    activation_params = None if activation != "clamped_swiglu" else {"limit": 7.0, "alpha": 1.702}
    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        activation_params=activation_params,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
        checkpoint=checkpoint,
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
586
587
588
589


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
):
    """Sanity-check a GPT-style ``TransformerLayer`` (no output layernorm, no attention mask).

    When a quantization recipe is given, the test is skipped if the device,
    the model shape, or the dtype does not support it.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            # Query the reason directly so this skip does not depend on a
            # module-level reason string being defined.
            pytest.skip(te.is_fp8_block_scaling_available(return_reason=True)[1])
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        # Use the module-level helper, as the other tests in this file do.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
644
645
646
647
648
649


def test_sanity_gpt_126m():
    """Run the GPT sanity test once on the "126m" config.

    A ``DelayedScaling`` recipe is used when FP8 is available; otherwise the
    test runs without quantization. The last entry of ``param_types`` is used
    as the dtype.
    """
    fp8_recipe = (
        recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
        if fp8_available
        else None
    )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
    )
Przemek Tredak's avatar
Przemek Tredak committed
665
666
667
668


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """Sanity-check a BERT-style ``TransformerLayer`` (post-layernorm residual, output layernorm)."""
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    layer_kwargs = dict(
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
    )
    block = TransformerLayer(
        config.hidden_size, 4 * config.hidden_size, config.num_heads, **layer_kwargs
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    """Run the BERT sanity test once on the "126m" config with a ``DelayedScaling`` recipe.

    The last entry of ``param_types`` is used as the dtype.
    """
    delayed_scaling = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=delayed_scaling,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
721
722
723
724


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """Sanity-check a decoder-type ``TransformerLayer`` fed with an encoder output."""
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    layer_kwargs = dict(
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        normalization=normalization,
        device="cuda",
    )
    block = TransformerLayer(
        config.hidden_size, 4 * config.hidden_size, config.num_heads, **layer_kwargs
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    """Run the T5 sanity test once on the "126m" config with a ``DelayedScaling`` recipe.

    The last entry of ``param_types`` is used as the dtype.
    """
    delayed_scaling = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=delayed_scaling,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
777
778
779
780


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a ``TransformerLayer`` with FP32 parameters run under AMP.

    The layer is always built with ``params_dtype=torch.float32``; ``dtype``
    is the autocast dtype applied by the AMP helper. When a quantization
    recipe is given, the test is skipped if the device, the model shape, or
    the dtype does not support it.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            # Query the reason directly so this skip does not depend on a
            # module-level reason string being defined.
            pytest.skip(te.is_fp8_block_scaling_available(return_reason=True)[1])
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        # Use the module-level helper, as the other tests in this file do.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
816
817
818
819


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
def test_sanity_drop_path(dtype, fp8_recipe, model):
    """Sanity-check a TransformerLayer built with ``drop_path_rate=1.0``.

    The layer is run through ``_test_sanity_e2e`` with ``skip_wgrad`` fixed to
    ``False``. The test is skipped when the requested FP8 recipe is not usable
    on this device or with this model config.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        # NOTE(review): the file header unpacks the block-scaling reason into `_`;
        # confirm `reason_for_no_fp8_block_scaling` is defined at module level.
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a TransformerLayer built with ``fuse_qkv_params=True``.

    The layer is run through ``_test_sanity_e2e``. The test is skipped when the
    requested FP8 recipe is not usable on this device or with this model
    config.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        # NOTE(review): the file header unpacks the block-scaling reason into `_`;
        # confirm `reason_for_no_fp8_block_scaling` is defined at module level.
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a TransformerLayer with fused weight-gradient accumulation.

    The layer is built with both ``fuse_qkv_params=True`` and
    ``fuse_wgrad_accumulation=True`` and run through
    ``_test_sanity_e2e_gradient_accumulation_fusion``. The test is skipped when
    the requested FP8 recipe is not usable on this device or with this model
    config.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        # NOTE(review): the file header unpacks the block-scaling reason into `_`;
        # confirm `reason_for_no_fp8_block_scaling` is defined at module level.
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


944
def test_model_multiple_cast():
    """Check that a Linear module's output dtype follows a later ``.half()`` cast.

    The module is first run on a float32 input, then the module and the input
    are both cast to float16 and the module is run again.
    """
    a = torch.zeros((16, 16), device="cuda")
    m = Linear(16, 32)

    y = m(a)
    assert y.dtype == torch.float32

    # Cast the already-used module and its input, then run it a second time.
    m.half()
    a = a.half()

    y2 = m(a)
    assert y2.dtype == torch.float16


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
    """Run ``general_gemm`` on operands that start at an offset into a buffer.

    Both ``N x N`` operands are views into one flat scratchpad, starting
    ``offset`` and ``2 * offset`` elements into it respectively. The test only
    checks that the GEMM and the following device synchronization complete.
    """
    scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
    # Each slice below holds exactly N * N elements.
    inp = torch.reshape(scratchpad[offset:-offset], (N, N))
    weight = torch.reshape(scratchpad[offset * 2 :], (N, N))

    _ = general_gemm(A=weight, B=inp)
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    """Run an FP8 ``general_gemm`` on operands sliced at an offset.

    A scratchpad is quantized with a ``Float8Quantizer`` (E4M3, scale and amax
    initialized to one); the two ``N x N`` operands are taken from its first
    row, one of them starting ``offset`` elements in. The output dtype is
    ``datatype``. The test only checks that the GEMM and the following device
    synchronization complete.
    """
    offset = 16
    scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)

    scales = torch.ones(1).cuda().squeeze()
    amaxes = torch.ones(1).cuda().squeeze()
    dtype = tex.DType.kFloat8E4M3
    fp8_quantizer = Float8Quantizer(scales, amaxes, dtype)

    scratchpad_fp8 = fp8_quantizer(scratchpad)
    # Row 0 has N * N + offset elements, so each slice holds exactly N * N.
    inp_fp8 = torch.reshape(scratchpad_fp8[0][:-offset], (N, N))
    weight_fp8 = torch.reshape(scratchpad_fp8[0][offset:], (N, N))
    general_gemm(
        weight_fp8,
        inp_fp8,
        datatype,
        bias=None,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()


997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """Verify that replace_raw_data swaps the buffer and keeps everything else."""
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    fp8_tensor = quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
    source = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
    quantizer.update_quantized(source, fp8_tensor)

    # Snapshot the attributes that must survive the swap untouched.
    tracked_names = ("_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid")
    snapshot = {name: getattr(fp8_tensor, name) for name in tracked_names}

    previous_buffer = fp8_tensor._data
    replacement_buffer = torch.empty_like(previous_buffer)
    replace_raw_data(fp8_tensor, replacement_buffer)

    # The tensor must now be backed by the replacement buffer...
    current_ptr = fp8_tensor._data.data_ptr()
    assert current_ptr != previous_buffer.data_ptr()
    assert current_ptr == replacement_buffer.data_ptr()
    # ...holding exactly the same values as before...
    torch.testing.assert_close(previous_buffer, fp8_tensor._data, atol=0, rtol=0)
    # ...and every tracked attribute must still be the very same object.
    for name, expected in snapshot.items():
        assert getattr(fp8_tensor, name) is expected


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_quantized_model_init_high_precision_init_val():
    """Test quantized_model_init with preserve_high_precision_init_val=True"""
    with quantized_model_init(preserve_high_precision_init_val=True):
        model = Linear(768, 768)

    weight = model.weight

    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
    assert hasattr(
        weight, "clear_high_precision_init_val"
    ), "clear_high_precision_init_val() not found"

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

    # Re-quantizing the preserved value must reproduce the stored weight exactly.
    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape, dtype=weight.dtype, device=weight.device
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)

    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )

    weight.clear_high_precision_init_val()
    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work"
    # Check the real attribute name; a name with a leading "." can never exist,
    # which would make this assertion pass unconditionally.
    assert not hasattr(
        weight, "_high_precision_init_val"
    ), "clear_high_precision_init_val() not work"


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules.

    The callable is the ``apply`` of an identity ``torch.autograd.Function``.
    The input gradient obtained through ``checkpoint`` is compared with the
    one obtained by calling the function directly.
    """

    # torch.autograd.Function
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    # `.grad` accumulates across backward passes, so take a snapshot and reset
    # it. Otherwise both names would refer to the same accumulated gradient and
    # the comparison below could never fail.
    grad_checkpoint = inp.grad.clone()
    inp.grad = None

    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)


1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
    """Test that backward memory stays bounded when weights are frozen.

    Uses ``autocast()`` with no explicit recipe, i.e. the default recipe. The
    check compares ``torch.cuda.max_memory_allocated()`` before and after the
    backward pass.
    """
    dim = 1024
    linear = Linear(dim, dim, bias=False)
    x = torch.randn(dim, dim, requires_grad=True, device="cuda")

    # Freeze weights
    linear.weight.requires_grad = False

    # Forward pass with FP8
    with autocast():
        o = linear(x)
        g_o = torch.randn_like(o)

    # Backward pass, bracketed by peak-memory readings
    max_memory_before_backward = torch.cuda.max_memory_allocated()
    o.backward(g_o)
    max_memory_after_backward = torch.cuda.max_memory_allocated()

    memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
    assert memory_diff < 5.5, (
        f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
        " grad_output should be quantized only columnwise."
    )


1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights

    A module is constructed under ``torch.no_grad()`` and
    ``quantized_model_init``, then run once under ``torch.inference_mode()``.
    After construction and again after the forward pass, its quantized
    parameters are checked to hold no transpose / column-wise data.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    # Tensor dimensions
    sequence_length = 32
    hidden_size = 32

    # Skip invalid configurations
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Construct quantization recipe
    with_quantization = quantization not in (None, "None")
    quantization_recipe = None
    if quantization == "fp8_delayed_scaling":
        quantization_recipe = recipe.DelayedScaling()
    elif quantization == "fp8_current_scaling":
        quantization_recipe = recipe.Float8CurrentScaling()
    elif quantization == "mxfp8":
        quantization_recipe = recipe.MXFP8BlockScaling()

    # Construct module
    module = None
    with torch.no_grad():
        with quantized_model_init(enabled=with_quantization, recipe=quantization_recipe):
            if module_name == "Linear":
                module = Linear(hidden_size, hidden_size)
            elif module_name == "LayerNormLinear":
                module = LayerNormLinear(hidden_size, hidden_size)
            elif module_name == "LayerNormMLP":
                module = LayerNormMLP(hidden_size, hidden_size)
            elif module_name == "GroupedLinear":
                module = GroupedLinear(1, hidden_size, hidden_size)
            elif module_name == "ops.Linear":
                module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)
            else:
                # Fail here rather than later on `None.parameters()`.
                raise ValueError(f"Unsupported module_name: {module_name!r}")

    def check_weights():
        """Helper function to check that weight parameters have expected data"""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Check that modules have expected weights after initialization
    check_weights()

    # Check that modules have expected weights after forward pass
    with torch.inference_mode():
        x = torch.zeros(sequence_length, hidden_size, device="cuda")
        kwargs = {}
        if module_name == "GroupedLinear":
            # GroupedLinear is built with a single group that takes every row.
            kwargs["m_splits"] = [sequence_length]
        with autocast(enabled=with_quantization, recipe=quantization_recipe):
            module(x, **kwargs)
    check_weights()