test_sanity.py 38 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
from typing import Optional

Przemek Tredak's avatar
Przemek Tredak committed
7
8
import torch
import pytest
9
import os
Przemek Tredak's avatar
Przemek Tredak committed
10

11
import transformer_engine.pytorch
12
13
14
15
16
from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    fp8_model_init,
)
Przemek Tredak's avatar
Przemek Tredak committed
17
18
19
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
20
    is_bf16_compatible,
Przemek Tredak's avatar
Przemek Tredak committed
21
22
23
24
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
25
    GroupedLinear,
Przemek Tredak's avatar
Przemek Tredak committed
26
27
    LayerNormMLP,
    TransformerLayer,
28
29
    RMSNorm,
    LayerNorm,
Przemek Tredak's avatar
Przemek Tredak committed
30
31
)
from transformer_engine.common import recipe
32
import transformer_engine_torch as tex
33
from transformer_engine.pytorch.cpp_extensions import general_gemm
34
from transformer_engine.pytorch.module.base import get_workspace
35
36
37
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
38
39
    Float8Quantizer,
    Float8Tensor,
40
)
41
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
42
from transformer_engine.pytorch.tensor.utils import replace_raw_data
43
from transformer_engine.pytorch.distributed import checkpoint
44
from utils import ModelConfig
Przemek Tredak's avatar
Przemek Tredak committed
45

46
# Only run FP8 tests on supported devices: query each quantization mode once at
# import time. These flags decide which recipes end up in ``fp8_recipes``.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Record initial RNG state from script run (CPU and CUDA generators).
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Opt-in debug mode: any non-zero integer in NVTE_TEST_NVINSPECT_ENABLED runs
# the same sanity tests with nvdlfw_inspect initialized.
NVTE_TEST_NVINSPECT_ENABLED = int(os.getenv("NVTE_TEST_NVINSPECT_ENABLED", "0"))

if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests should behave the same when debug=True. They are fed a
    # dummy feature so that debug mode is not switched off, which can happen
    # when no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )

71

72
73
74
75
76
77
78
79
80
def is_fp8_supported(config: ModelConfig):
    """Return whether ``config`` has shapes usable by the FP8 sanity tests.

    Both token counts (``max_seqlen_q * batch_size`` and
    ``max_seqlen_kv * batch_size``) and both hidden sizes (``hidden_size`` and
    ``hidden_size_kv``) must be multiples of 16. Callers skip the FP8 variant
    of a test when this returns ``False``.
    """
    token_counts = (
        config.max_seqlen_q * config.batch_size,
        config.max_seqlen_kv * config.batch_size,
    )
    hidden_sizes = (config.hidden_size, config.hidden_size_kv)
    return all(size % 16 == 0 for size in token_counts + hidden_sizes)
Przemek Tredak's avatar
Przemek Tredak committed
81

82

Przemek Tredak's avatar
Przemek Tredak committed
83
# Model shapes shared by the sanity tests, keyed by the name passed through
# ``@pytest.mark.parametrize("model", ...)``. Positional arguments go straight
# to ``ModelConfig`` from the tests' ``utils`` module.
model_configs = {
    # Larger config, used by the *_126m tests.
    "126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
    "small": ModelConfig(2, 32, 2, 32, num_layers=2),
    "weird": ModelConfig(3, 37, 3, 23, num_layers=2),
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}

90
91
92
93
94
95
96
97
98

def nvfp4_vanilla():
    """Build an ``NVFP4BlockScaling`` recipe with plain ``QParams`` everywhere.

    The forward-input, forward-weight and backward-gradient quantization
    parameters are each overwritten with a freshly constructed
    ``recipe.QParams()``, replacing whatever the recipe set by default.
    """
    fp4_recipe = recipe.NVFP4BlockScaling()
    for field_name in ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad"):
        setattr(fp4_recipe, field_name, recipe.QParams())
    return fp4_recipe


99
100
101
# Recipes the parametrized tests iterate over, gated on hardware support.
# ``None`` (run without fp8_autocast) is always the last entry.
fp8_recipes = []
if mxfp8_available:
    # TODO: fix check for this -- NVFP4 is currently gated on MXFP8 availability.
    fp8_recipes += [recipe.MXFP8BlockScaling(), nvfp4_vanilla()]
if fp8_block_scaling_available:
    fp8_recipes += [recipe.Float8BlockScaling()]
if fp8_available:
    fp8_recipes += [recipe.Float8CurrentScaling(), recipe.DelayedScaling()]
fp8_recipes += [None]
Przemek Tredak's avatar
Przemek Tredak committed
109

110
# Parameter dtypes to test; bf16 requires sm_80 or higher.
param_types = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_compatible() else []
)

all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]

# Values passed to the ``activation=`` argument of LayerNormMLP / TransformerLayer.
all_activations = [
    "gelu",
    "geglu",
    "qgelu",
    "qgeglu",
    "relu",
    "reglu",
    "srelu",
    "sreglu",
    "silu",
    "swiglu",
]
# Values passed to the ``normalization=`` argument; also picks RMSNorm vs LayerNorm.
all_normalizations = ["LayerNorm", "RMSNorm"]
schetlur-nv's avatar
schetlur-nv committed
130

131

schetlur-nv's avatar
schetlur-nv committed
132
133
134
135
136
def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


137
138
139
140
141
142
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse teardown fixture that clears FP8 global state between tests.

    Yields control to the test, then calls ``FP8GlobalStateManager.reset()``.
    """
    yield  # the test body runs here
    FP8GlobalStateManager.reset()  # teardown after every test


143
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one fwd/bwd pass of ``block`` under torch AMP and check the dtypes.

    The FP32 input is fed to ``block`` inside ``torch.autocast(dtype)``, with
    ``fp8_autocast`` nested inside when ``fp8_recipe`` is given. The output must
    come back in ``dtype``. The input gradient and every trainable parameter's
    gradient must stay FP32.

    Args:
        block: module called as ``block(x, attention_mask=mask)``.
        dtype: autocast dtype expected on the output.
        config: provides ``max_seqlen_q``, ``max_seqlen_kv``, ``batch_size`` and
            ``hidden_size``.
        fp8_recipe: FP8 recipe, or ``None`` to disable ``fp8_autocast``.
        skip_wgrad: freeze all parameters of ``block`` before the pass.
    """
    seq_q = config.max_seqlen_q
    hidden_states = torch.randn(
        (seq_q, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    hidden_states.retain_grad()
    attn_mask = torch.randint(
        2, (1, 1, seq_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda"
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
            out = block(hidden_states, attention_mask=attn_mask)
        loss = out.sum()
    loss.backward()
    torch.cuda.synchronize()

    assert out.dtype == dtype, "AMP wrong output type."
    assert hidden_states.grad is not None, "Gradient should not be empty"
    assert hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    trainable = ((n, p) for n, p in block.named_parameters() if p.requires_grad)
    for name, param in trainable:
        assert param.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


178
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    """Check that weight gradients are accumulated into ``param.main_grad``.

    Every trainable parameter whose name contains "weight" (but not
    "layer_norm_weight") gets a zero ``main_grad`` buffer before one fwd/bwd
    pass. Afterwards each such buffer must contain at least one non-zero entry.

    Args:
        block: module called as ``block(x, attention_mask=mask)``.
        dtype: dtype of the random input.
        config: provides ``max_seqlen_q``, ``max_seqlen_kv``, ``batch_size`` and
            ``hidden_size``.
        fp8_recipe: FP8 recipe, or ``None`` to disable ``fp8_autocast``.
        skip_wgrad: freeze all parameters of ``block`` before the pass.
    """

    def _uses_main_grad(param_name, param):
        # Same selection when seeding the buffers and when checking them.
        return (
            "layer_norm_weight" not in param_name
            and "weight" in param_name
            and param.requires_grad
        )

    seq_q = config.max_seqlen_q
    hidden_states = torch.randn(
        (seq_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    attn_mask = torch.randint(
        2, (1, 1, seq_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda"
    )

    if skip_wgrad:
        _disable_wgrads(block)

    for name, param in block.named_parameters():
        if _uses_main_grad(name, param):
            param.main_grad = torch.zeros_like(param)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        out = block(hidden_states, attention_mask=attn_mask)
    out.sum().backward()
    torch.cuda.synchronize()

    failed_grads = [
        name
        for name, param in block.named_parameters()
        if _uses_main_grad(name, param) and torch.count_nonzero(param.main_grad) == 0
    ]
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
216

Przemek Tredak's avatar
Przemek Tredak committed
217

218
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test one unmasked forward and backward pass of ``block``.

    Args:
        block: module called as ``block(x)``.
        dtype: dtype of the random input.
        config: provides ``max_seqlen_q``, ``batch_size`` and ``hidden_size``.
        fp8_recipe: FP8 recipe, or ``None`` to disable ``fp8_autocast``.
        skip_wgrad: freeze all parameters of ``block`` before the pass.
    """
    input_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    hidden_states = torch.randn(input_shape, dtype=dtype, device="cuda", requires_grad=True)

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        out = block(hidden_states)
    out.sum().backward()
    torch.cuda.synchronize()


237
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test one fwd/bwd pass of ``block`` with a per-sequence bool mask.

    The random boolean mask has shape ``(batch_size, 1, 1, max_seqlen_q)``.

    Args:
        block: module called as ``block(x, attention_mask=mask)``.
        dtype: dtype of the random input.
        config: provides ``max_seqlen_q``, ``batch_size`` and ``hidden_size``.
        fp8_recipe: FP8 recipe, or ``None`` to disable ``fp8_autocast``.
        skip_wgrad: freeze all parameters of ``block`` before the pass.
    """
    seq_q, batch = config.max_seqlen_q, config.batch_size
    hidden_states = torch.randn(
        (seq_q, batch, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True
    )
    attn_mask = torch.randint(2, (batch, 1, 1, seq_q), dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        out = block(hidden_states, attention_mask=attn_mask)
    out.sum().backward()
    torch.cuda.synchronize()


263
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test one fwd/bwd pass of a decoder ``block``.

    The input itself is passed as ``encoder_output``. Two random boolean masks
    are used:
    ``attention_mask`` of shape ``(1, 1, max_seqlen_q, max_seqlen_kv)`` and
    ``enc_dec_attn_mask`` of shape ``(batch_size, 1, 1, max_seqlen_kv)``.

    Args:
        block: module accepting ``attention_mask``, ``encoder_output`` and
            ``enc_dec_attn_mask`` keyword arguments.
        dtype: dtype of the random input.
        config: provides ``max_seqlen_q``, ``max_seqlen_kv``, ``batch_size`` and
            ``hidden_size``.
        fp8_recipe: FP8 recipe, or ``None`` to disable ``fp8_autocast``.
        skip_wgrad: freeze all parameters of ``block`` before the pass.
    """
    seq_q, seq_kv, batch = config.max_seqlen_q, config.max_seqlen_kv, config.batch_size
    hidden_states = torch.randn(
        (seq_q, batch, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True
    )
    self_attn_mask = torch.randint(2, (1, 1, seq_q, seq_kv), dtype=torch.bool, device="cuda")
    cross_attn_mask = torch.randint(2, (batch, 1, 1, seq_kv), dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        out = block(
            hidden_states,
            attention_mask=self_attn_mask,
            encoder_output=hidden_states,
            enc_dec_attn_mask=cross_attn_mask,
        )
    out.sum().backward()
    torch.cuda.synchronize()


300
301
302
def _test_sanity_common(
    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
    """Smoke-test fwd/bwd of a single module, optionally over two microbatches.

    Args:
        block: module called as ``block(x)`` or
            ``block(x, is_first_microbatch=...)``. A tuple output is reduced to
            its first element.
        dtype: dtype of the random input.
        config: provides ``max_seqlen_q``, ``batch_size`` and ``hidden_size``.
        fp8_recipe: FP8 recipe, or ``None`` to disable ``fp8_autocast``.
        skip_wgrad: freeze all parameters of ``block`` before the pass.
        skip_dgrad: create the input with ``requires_grad=False``.
        microbatching: when true, run ``block`` twice (first/non-first
            microbatch) and backpropagate only through the second output.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        if microbatching:
            _ = block(inp, is_first_microbatch=True)
            out = block(inp, is_first_microbatch=False)
        else:
            out = block(inp)

    out = out[0] if isinstance(out, tuple) else out
    out.sum().backward()
    torch.cuda.synchronize()


330
def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    """Run a normalization module under ``torch.autocast`` and check dtypes.

    The FP32 input goes through ``block`` inside ``torch.autocast(dtype)``, and
    the output must come back in ``dtype``. Every gradient that is actually
    computed must stay FP32: the input gradient unless ``skip_dgrad``, and the
    parameter gradients unless ``skip_wgrad``.

    Args:
        block: normalization module called as ``block(x)``.
        dtype: autocast dtype expected on the output.
        config: provides ``max_seqlen_q``, ``batch_size`` and ``hidden_size``.
        skip_wgrad: freeze all parameters of ``block`` before the pass.
        skip_dgrad: create the input with ``requires_grad=False``.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    # Honor both skip flags the same way ``_test_sanity_common`` does, so each
    # parametrization exercises a distinct gradient configuration.
    te_inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=not skip_dgrad,
    )
    # te_inp is a leaf tensor, so its .grad is populated by backward() without
    # retain_grad(). retain_grad() would also raise when requires_grad is False.

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    if not skip_dgrad:
        assert te_inp.grad is not None, "Gradient should not be empty"
        assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    """Check an FP32 RMSNorm/LayerNorm module under ``torch.autocast(dtype)``."""
    config = model_configs[model]
    if normalization == "RMSNorm":
        norm_cls = RMSNorm
    else:
        norm_cls = LayerNorm

    norm = norm_cls(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(norm, dtype, config, skip_wgrad, skip_dgrad)
367
368


Przemek Tredak's avatar
Przemek Tredak committed
369
370
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
    microbatching,
):
    """Sanity-check LayerNormLinear (hidden -> 3 * hidden) fwd/bwd."""
    config = model_configs[model]
    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    std = 0.023
    block = LayerNormLinear(
        config.hidden_size,
        3 * config.hidden_size,
        init_method=init_method_normal(std),
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(
        block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=microbatching
    )
Przemek Tredak's avatar
Przemek Tredak committed
408
409
410
411


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    """Sanity-check a square Linear (hidden -> hidden) fwd/bwd."""
    config = model_configs[model]
    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    std = 0.023
    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=scaled_init_method_normal(std, config.num_layers),
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(
        block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=microbatching
    )
Przemek Tredak's avatar
Przemek Tredak committed
436
437


438
439
440
441
442
443
444
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    """Run Linear fwd/bwd on a 2D input whose token count may be zero.

    ``num_tokens = bs * max_seqlen_q``, so ``bs == 0`` exercises an empty input.
    The output must have shape ``(num_tokens, 4 * hidden_size)``.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.max_seqlen_q

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    # Allocate directly on the GPU so the input is a CUDA leaf tensor. Calling
    # .cuda() on a CPU tensor with requires_grad=True yields a non-leaf copy,
    # and its gradient is sent back to the hidden CPU original.
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


473
474
475
476
477
478
479
480
481
482
483
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
    """Run GroupedLinear fwd/bwd with one of the ``num_gemms`` splits empty.

    Each non-empty split has ``16 * bs * max_seqlen_q`` rows. The split picked
    by ``empty_split`` is set to 0, so the input has
    ``(num_gemms - 1) * 16 * bs * max_seqlen_q`` rows in total.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4():
            pytest.skip("NVFP4 not supported for grouped linear")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    # Allocate directly on the GPU so the input is a CUDA leaf tensor. Calling
    # .cuda() on a CPU tensor with requires_grad=True yields a non-leaf copy,
    # and its gradient is sent back to the hidden CPU original.
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )
    m_splits = [bs * config.max_seqlen_q] * num_gemms
    if empty_split == "first":
        m_splits[0] = 0
    elif empty_split == "last":
        m_splits[-1] = 0
    elif empty_split == "middle":
        m_splits[num_gemms // 2] = 0

    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_grouped_linear(inp_hidden_states, m_splits)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


Przemek Tredak's avatar
Przemek Tredak committed
522
523
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
    microbatching,
):
    """Sanity-check LayerNormMLP (FFN width 4 * hidden) fwd/bwd."""
    config = model_configs[model]
    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    std = 0.023
    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(
        block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=microbatching
    )
Przemek Tredak's avatar
Przemek Tredak committed
566
567
568
569


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
):
    """Sanity-check a GPT-style TransformerLayer fwd/bwd.

    The layer is built with ``apply_residual_connection_post_layernorm=False``
    and ``output_layernorm=False``. It is also called directly by
    ``test_sanity_gpt_126m``.
    """
    config = model_configs[model]
    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )
    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
618
619
620
621
622
623


def test_sanity_gpt_126m():
    """Run ``test_sanity_gpt`` on the 126m config.

    Uses the last entry of ``param_types`` (bf16 when supported, else fp16).
    Uses a DelayedScaling recipe when FP8 is available and no FP8 otherwise.
    """
    fp8_recipe = (
        recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
        if fp8_available
        else None
    )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
    )
Przemek Tredak's avatar
Przemek Tredak committed
639
640
641
642


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """Sanity-check a BERT-style TransformerLayer fwd/bwd.

    The layer is built with ``apply_residual_connection_post_layernorm=True``,
    ``output_layernorm=True`` and a causal self-attention mask type. It is also
    called directly by ``test_sanity_bert_126m``.
    """
    config = model_configs[model]
    quantized = fp8_recipe is not None
    if quantized and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
    )
    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    """Run ``test_sanity_bert`` on the 126m config with a DelayedScaling recipe.

    Uses the last entry of ``param_types``. When FP8 is unavailable,
    ``test_sanity_bert`` skips with ``reason_for_no_fp8``.
    """
    delayed_scaling = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=delayed_scaling,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
695
696
697
698


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """Sanity-check a decoder TransformerLayer (``layer_type="decoder"``) fwd/bwd.

    Also called directly by ``test_sanity_T5_126m``.
    """
    config = model_configs[model]
    quantized = fp8_recipe is not None
    if quantized and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        normalization=normalization,
        device="cuda",
    )
    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    """Run ``test_sanity_T5`` on the 126m config with a DelayedScaling recipe.

    Uses the last entry of ``param_types``. When FP8 is unavailable,
    ``test_sanity_T5`` skips with ``reason_for_no_fp8``.
    """
    delayed_scaling = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=delayed_scaling,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
751
752
753
754


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    """Run an FP32-parameter TransformerLayer under ``torch.autocast(dtype)``.

    ``dtype`` is used only as the autocast dtype; parameters stay FP32.
    """
    config = model_configs[model]
    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )
    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
784
785
786
787


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
def test_sanity_drop_path(dtype, fp8_recipe, model):
    """Sanity-check a TransformerLayer built with ``drop_path_rate=1.0``."""
    config = model_configs[model]
    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )
    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad=False)
Przemek Tredak's avatar
Przemek Tredak committed
819
820
821
822


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    """End-to-end sanity check of a TransformerLayer built with ``fuse_qkv_params=True``.

    Parameters are created in ``dtype``; ``fp8_recipe`` and ``skip_wgrad`` are
    forwarded to ``_test_sanity_e2e``.
    """
    config = model_configs[model]

    # Guard clauses: skip combinations the quantization recipe cannot handle.
    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    std = 0.023
    layer_kwargs = dict(
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )
    layer = TransformerLayer(
        config.hidden_size, 4 * config.hidden_size, config.num_heads, **layer_kwargs
    )

    _test_sanity_e2e(layer, dtype, config, fp8_recipe, skip_wgrad)
855
856
857
858


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity check of a TransformerLayer with fused QKV params and ``fuse_wgrad_accumulation=True``.

    The check itself is delegated to ``_test_sanity_e2e_gradient_accumulation_fusion``.
    """
    config = model_configs[model]

    # Guard clauses: skip combinations the quantization recipe cannot handle.
    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    std = 0.023
    layer_kwargs = dict(
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )
    layer = TransformerLayer(
        config.hidden_size, 4 * config.hidden_size, config.num_heads, **layer_kwargs
    )

    _test_sanity_e2e_gradient_accumulation_fusion(layer, dtype, config, fp8_recipe, skip_wgrad)
892
893


894
def test_model_multiple_cast():
    """Check that a Linear's output dtype follows its parameters after ``.half()``."""
    inp = torch.zeros((16, 16), device="cuda")
    layer = Linear(16, 32)

    out_fp32 = layer(inp)
    assert out_fp32.dtype == torch.float32

    # Cast the module, then feed a matching half-precision input.
    layer.half()
    out_fp16 = layer(inp.half())
    assert out_fp16.dtype == torch.float16
906
907
908
909
910
911


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
    """Run ``general_gemm`` on (N, N) operands that are views at element offsets into one buffer."""
    buf = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
    square = (N, N)
    # Both operands alias ``buf``: the input starts at ``offset``, the weight at ``2 * offset``.
    inp = buf[offset:-offset].reshape(square)
    weight = buf[2 * offset :].reshape(square)

    _ = general_gemm(A=weight, B=inp, workspace=get_workspace())
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    """Run an FP8 ``general_gemm`` on operands sliced at an offset from one quantized buffer."""
    offset = 16
    buf = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)

    # E4M3 quantizer with unit scale and amax tensors.
    quantizer = Float8Quantizer(
        torch.ones(1).cuda().squeeze(),
        torch.ones(1).cuda().squeeze(),
        tex.DType.kFloat8E4M3,
    )

    quantized = quantizer(buf)
    # Two overlapping (N, N) views of row 0, shifted by ``offset`` elements.
    inp_fp8 = torch.reshape(quantized[0][:-offset], (N, N))
    weight_fp8 = torch.reshape(quantized[0][offset:], (N, N))

    general_gemm(
        weight_fp8,
        inp_fp8,
        get_workspace(),
        datatype,
        bias=None,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()
946
947


948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """Check that ``replace_raw_data`` swaps in a new ``_data`` buffer.

    After the swap the tensor must point at the new buffer and hold the same
    values. Every other watched attribute must still be the identical object.
    """
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    tensor = quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
    source = torch.randn(tensor.shape, dtype=torch.bfloat16, device="cuda")
    quantizer.update_quantized(source, tensor)

    watched = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
    before = {name: getattr(tensor, name) for name in watched}

    previous = tensor._data
    replacement = torch.empty_like(previous)
    replace_raw_data(tensor, replacement)

    # The tensor must now be backed by the replacement buffer...
    assert tensor._data.data_ptr() != previous.data_ptr()
    assert tensor._data.data_ptr() == replacement.data_ptr()
    # ...hold bit-identical values...
    torch.testing.assert_close(previous, tensor._data, atol=0, rtol=0)
    # ...and keep every other watched attribute as the very same object.
    for name in watched:
        assert getattr(tensor, name) is before[name]


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_init_high_precision_init_val():
    """Test fp8_model_init with preserve_high_precision_init_val=True.

    Checks that a quantized ``Linear`` weight exposes a CPU copy of its
    high-precision initial value. Re-quantizing that copy with the weight's own
    quantizer must reproduce the weight exactly. Finally,
    ``clear_high_precision_init_val()`` must remove the stored value.
    """
    with fp8_model_init(preserve_high_precision_init_val=True):
        model = Linear(768, 768)

    weight = model.weight

    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
    assert hasattr(
        weight, "clear_high_precision_init_val"
    ), "clear_high_precision_init_val() not found"

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

    # Re-quantize the preserved value into a fresh tensor; it must match the
    # existing quantized weight bit-exactly.
    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape, dtype=weight.dtype, device=weight.device
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)

    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )

    weight.clear_high_precision_init_val()
    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work"
    # Check the real attribute name (no leading dot) so this assertion can fail.
    assert not hasattr(
        weight, "_high_precision_init_val"
    ), "clear_high_precision_init_val() not work"
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules.

    Wraps an identity ``torch.autograd.Function`` (a plain callable, not an
    ``nn.Module``) in ``checkpoint``. The input gradient must match the one
    produced by calling the function directly.
    """

    # Identity torch.autograd.Function used as a bare callable.
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    assert inp.grad is not None, "checkpoint() did not produce an input gradient"
    # backward() accumulates into ``inp.grad``. Snapshot it and reset it, so the
    # reference pass below produces an independent gradient. Otherwise both
    # names could alias one accumulated tensor and the comparison would be vacuous.
    grad_checkpoint = inp.grad.clone()
    inp.grad = None

    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)
1041
1042


1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
    """Test that backward memory stays low when Linear weights are frozen.

    Runs with the default ``fp8_autocast()`` recipe (no recipe passed). The
    peak-memory growth during ``backward`` must stay below 5.5 MB. The assertion
    message expects grad_output to be quantized only column-wise in this case.
    """
    dim = 1024
    linear = Linear(dim, dim, bias=False)
    x = torch.randn(dim, dim, requires_grad=True, device="cuda")

    # Freeze weights
    linear.weight.requires_grad = False

    # Forward and backward pass with FP8
    with fp8_autocast():
        o = linear(x)
        g_o = torch.randn_like(o)

    # max_memory_allocated() reports the peak since process start or the last
    # reset. Reset it here so peaks from the forward pass, or from earlier tests
    # in the same process, cannot hide the allocations made during backward.
    torch.cuda.reset_peak_memory_stats()
    max_memory_before_backward = torch.cuda.max_memory_allocated()
    o.backward(g_o)
    max_memory_after_backward = torch.cuda.max_memory_allocated()

    memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
    assert memory_diff < 5.5, (
        f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
        " grad_output should be quantized only columnwise."
    )


1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights

    Builds ``module_name`` under ``fp8_model_init``. Quantized weights must hold
    only the data expected for inference: no FP8 transpose and no column-wise
    MXFP8 data. This is checked right after construction and again after a
    forward pass under ``torch.inference_mode``.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    # Tensor dimensions
    seq_len = 32
    width = 32

    # Skip configurations the hardware cannot run
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Map each quantization name to its recipe class (None -> no recipe)
    recipe_factories = {
        "fp8_delayed_scaling": recipe.DelayedScaling,
        "fp8_current_scaling": recipe.Float8CurrentScaling,
        "mxfp8": recipe.MXFP8BlockScaling,
    }
    with_quantization = quantization not in (None, "None")
    make_recipe = recipe_factories.get(quantization)
    quantization_recipe = make_recipe() if make_recipe is not None else None

    # Map each module name to a zero-argument constructor
    module_factories = {
        "Linear": lambda: Linear(width, width),
        "LayerNormLinear": lambda: LayerNormLinear(width, width),
        "LayerNormMLP": lambda: LayerNormMLP(width, width),
        "GroupedLinear": lambda: GroupedLinear(1, width, width),
        "ops.Linear": lambda: transformer_engine.pytorch.ops.Linear(width, width),
    }
    module = None
    with torch.no_grad(), fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
        make_module = module_factories.get(module_name)
        if make_module is not None:
            module = make_module()

    def check_weights():
        """Assert that quantized parameters carry only inference-time data."""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Weights right after initialization
    check_weights()

    # Weights after a forward pass in inference mode
    with torch.inference_mode():
        x = torch.zeros(seq_len, width, device="cuda")
        kwargs = {"m_splits": [seq_len]} if module_name == "GroupedLinear" else {}
        with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
            module(x, **kwargs)
    check_weights()