test_sanity.py 38 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
from typing import Optional

Przemek Tredak's avatar
Przemek Tredak committed
7
8
import torch
import pytest
9
import os
Przemek Tredak's avatar
Przemek Tredak committed
10

11
12
13
import transformer_engine
import transformer_engine.pytorch as te
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
Przemek Tredak's avatar
Przemek Tredak committed
14
15
16
17
18
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch import (
19
20
    autocast,
    quantized_model_init,
Przemek Tredak's avatar
Przemek Tredak committed
21
22
    LayerNormLinear,
    Linear,
23
    GroupedLinear,
Przemek Tredak's avatar
Przemek Tredak committed
24
25
    LayerNormMLP,
    TransformerLayer,
26
27
    RMSNorm,
    LayerNorm,
28
29
30
31
32
33
34
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
    Float8Tensor,
    MXFP8Tensor,
    checkpoint,
    QuantizedTensor,
    is_bf16_available,
Przemek Tredak's avatar
Przemek Tredak committed
35
36
)
from transformer_engine.common import recipe
37
import transformer_engine_torch as tex
38
from transformer_engine.pytorch.cpp_extensions import general_gemm
39
from transformer_engine.pytorch.module.base import get_workspace
40
from transformer_engine.pytorch.tensor.utils import replace_raw_data
41
from utils import ModelConfig
Przemek Tredak's avatar
Przemek Tredak committed
42

43
# Only run FP8 tests on supported devices. Each probe below returns an
# (is_available, reason) pair; the reasons are used as pytest skip messages.
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(
    return_reason=True
)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
47

48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Record initial RNG state from script run: seed torch's default generator
# and the CUDA generator with the same value.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Integer opt-in flag for running the sanity tests in nvinspect debug mode
# (treated as "0" when the variable is unset).
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))


if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests should behave the same when debug=True. The debug API is
    # initialized with a dummy feature config so that debug mode is not switched
    # off, which can happen when no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )

68

69
70
71
72
73
74
75
76
77
def is_fp8_supported(config: ModelConfig):
    """Return True when every FP8-relevant size in ``config`` is a multiple of 16.

    The sizes checked, in order, are ``max_seqlen_q * batch_size``,
    ``max_seqlen_kv * batch_size``, ``hidden_size`` and ``hidden_size_kv``.
    Tests call this to skip FP8 recipes for configs that fail the check.
    """
    misaligned = (
        config.max_seqlen_q * config.batch_size % 16
        or config.max_seqlen_kv * config.batch_size % 16
        or config.hidden_size % 16
        or config.hidden_size_kv % 16
    )
    return not misaligned
Przemek Tredak's avatar
Przemek Tredak committed
78

79

Przemek Tredak's avatar
Przemek Tredak committed
80
# Named model configurations selected by the tests' "model" parameter.
# NOTE(review): the positional arguments presumably are
# (batch_size, max_seqlen, num_heads, head_dim) -- confirm against utils.ModelConfig.
model_configs = {}
model_configs["126m"] = ModelConfig(2, 2048, 12, 64, num_layers=12)
model_configs["small"] = ModelConfig(2, 32, 2, 32, num_layers=2)
model_configs["weird"] = ModelConfig(3, 37, 3, 23, num_layers=2)
model_configs["large"] = ModelConfig(2, 128, 4, 128, num_layers=1)

87
88
89
90
91
92
93
94
95

def nvfp4_vanilla():
    """Build an ``NVFP4BlockScaling`` recipe with default quantization params.

    The forward-input, forward-weight and backward-grad quantization settings
    are each replaced with a freshly constructed ``recipe.QParams()``.
    """
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    for qparams_attr in ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad"):
        setattr(nvfp4_recipe, qparams_attr, recipe.QParams())
    return nvfp4_recipe


96
97
98
# Quantization recipes to parametrize over, gated on hardware support.
# ``None`` is always last and means "run without quantization".
fp8_recipes = []
if mxfp8_available:
    # TODO: NVFP4 is currently gated on the MXFP8 availability check; fix check for this.
    fp8_recipes += [recipe.MXFP8BlockScaling(), nvfp4_vanilla()]
if fp8_block_scaling_available:
    fp8_recipes += [recipe.Float8BlockScaling()]
if fp8_available:
    fp8_recipes += [recipe.Float8CurrentScaling(), recipe.DelayedScaling()]
fp8_recipes += [None]
Przemek Tredak's avatar
Przemek Tredak committed
106

107
# Parameter dtypes covered by the tests; bf16 requires sm_80 or higher.
param_types = [
    torch.float32,
    torch.float16,
    *([torch.bfloat16] if is_bf16_available() else []),
]
Przemek Tredak's avatar
Przemek Tredak committed
110

111
# Shared parametrization values.
all_boolean = [True, False]
# Includes 0 so that zero-token inputs are exercised.
batch_sizes_with_zero = list(range(3))

# Activation names passed to LayerNormMLP.
all_activations = (
    "gelu geglu qgelu qgeglu relu reglu srelu sreglu silu swiglu clamped_swiglu"
).split()
all_normalizations = ["LayerNorm", "RMSNorm"]
schetlur-nv's avatar
schetlur-nv committed
128

129

schetlur-nv's avatar
schetlur-nv committed
130
131
132
133
134
def _disable_wgrads(block):
    """Freeze every parameter of ``block`` so no weight gradients are computed."""
    for param in block.parameters():
        param.requires_grad_(False)


135
136
137
138
139
140
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture: after each test finishes, call ``FP8GlobalStateManager.reset()``."""
    yield None
    # Teardown runs after the test body.
    FP8GlobalStateManager.reset()


141
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run ``block`` forward/backward under torch AMP and check dtypes.

    The forward runs inside ``torch.autocast(dtype=dtype)``. When ``fp8_recipe``
    is given, it also runs inside TE's ``autocast``. The test asserts that:
    - the output is in ``dtype``;
    - the fp32 input's gradient is fp32;
    - every trainable parameter's gradient is fp32.
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    hidden_states.retain_grad()
    attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
            out = block(hidden_states, attention_mask=attn_mask)
        # The loss is reduced under torch AMP but outside the TE autocast region.
        loss = out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert out.dtype == dtype, "AMP wrong output type."
    input_grad = hidden_states.grad
    assert input_grad is not None, "Gradient should not be empty"
    assert input_grad.dtype == torch.float32, "AMP wrong dgrad type."
    trainable = ((name, p) for name, p in block.named_parameters() if p.requires_grad)
    for name, param in trainable:
        assert param.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


176
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    """Check that weight gradients are accumulated into ``param.main_grad``.

    Every trainable parameter whose name contains "weight" (but not
    "layer_norm_weight") gets a zero ``main_grad`` buffer before the forward.
    After backward, each such buffer must contain at least one nonzero entry.
    """

    def _uses_main_grad(param_name, param):
        # LayerNorm weights are excluded from the main_grad check.
        return (
            "layer_norm_weight" not in param_name
            and "weight" in param_name
            and param.requires_grad
        )

    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    for param_name, param in block.named_parameters():
        if _uses_main_grad(param_name, param):
            param.main_grad = torch.zeros_like(param)

    with autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        out = block(hidden_states, attention_mask=attn_mask)
    loss = out.sum()
    loss.backward()
    torch.cuda.synchronize()

    failed_grads = [
        param_name
        for param_name, param in block.named_parameters()
        if _uses_main_grad(param_name, param) and not torch.count_nonzero(param.main_grad) > 0
    ]
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
214

Przemek Tredak's avatar
Przemek Tredak committed
215

216
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
    """Forward ``block`` on random hidden states (no attention mask), then backprop a sum loss."""
    inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    quantize = fp8_recipe is not None
    with autocast(enabled=quantize, recipe=fp8_recipe):
        out = block(inp)
    out.sum().backward()
    torch.cuda.synchronize()


235
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    """Forward ``block`` with a random per-sample boolean attention mask, then backprop.

    The mask has shape ``(batch_size, 1, 1, max_seqlen_q)``.
    """
    inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    mask_shape = (config.batch_size, 1, 1, config.max_seqlen_q)
    attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    quantize = fp8_recipe is not None
    with autocast(enabled=quantize, recipe=fp8_recipe):
        out = block(inp, attention_mask=attn_mask)
    out.sum().backward()
    torch.cuda.synchronize()


261
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    """Forward a block that takes ``encoder_output`` with self- and cross-attention masks.

    The random hidden states are also passed as ``encoder_output``. The
    forward is followed by a sum-loss backward.
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    self_attn_mask = torch.randint(
        2, (1, 1, config.max_seqlen_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda"
    )
    cross_attn_mask = torch.randint(
        2, (config.batch_size, 1, 1, config.max_seqlen_kv), dtype=torch.bool, device="cuda"
    )

    if skip_wgrad:
        _disable_wgrads(block)

    quantize = fp8_recipe is not None
    with autocast(enabled=quantize, recipe=fp8_recipe):
        out = block(
            hidden_states,
            attention_mask=self_attn_mask,
            encoder_output=hidden_states,
            enc_dec_attn_mask=cross_attn_mask,
        )
    out.sum().backward()
    torch.cuda.synchronize()


298
299
300
def _test_sanity_common(
    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
    """Shared forward/backward sanity check for single TE modules.

    ``skip_dgrad`` makes the input not require a gradient. ``skip_wgrad``
    freezes the block's parameters; the test is skipped when both are set.
    With ``microbatching`` the block is called twice, first with
    ``is_first_microbatch=True`` and then with ``False``, and only the second
    output is backpropagated. Tuple outputs are reduced to their first element.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        if microbatching:
            _ = block(inp, is_first_microbatch=True)
            out = block(inp, is_first_microbatch=False)
        else:
            out = block(inp)

    primary = out[0] if isinstance(out, tuple) else out
    primary.sum().backward()
    torch.cuda.synchronize()


328
def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    """Run a normalization ``block`` forward/backward under torch AMP and check dtypes.

    Args:
        block: module under test, called as ``block(inp)``.
        dtype: torch AMP dtype; the output is expected to have this dtype.
        config: model config providing ``max_seqlen_q``, ``batch_size`` and
            ``hidden_size`` for the fp32 input.
        skip_wgrad: if True, freeze the block's parameters (no weight grads).
        skip_dgrad: if True, the input does not require a gradient.

    The test is skipped when both flags are set, since nothing would require a
    gradient. Gradients that are computed must be fp32.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    # Honor skip_dgrad. Previously the input always required grad, which made
    # the skip_dgrad parametrization a no-op.
    te_inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=not skip_dgrad,
    )
    if not skip_dgrad:
        # retain_grad() raises on a tensor that does not require grad.
        te_inp.retain_grad()

    # Honor skip_wgrad. Previously the parameters were never frozen.
    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    if skip_dgrad:
        assert te_inp.grad is None, "Input gradient should not be computed with skip_dgrad."
    else:
        assert te_inp.grad is not None, "Gradient should not be empty"
        assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    """AMP sanity test for TE's standalone RMSNorm / LayerNorm modules (fp32 params)."""
    config = model_configs[model]
    # Any value other than "RMSNorm" selects LayerNorm.
    norm_cls = {"RMSNorm": RMSNorm}.get(normalization, LayerNorm)
    block = norm_cls(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
365
366


Przemek Tredak's avatar
Przemek Tredak committed
367
368
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
    microbatching,
):
    """Sanity forward/backward for LayerNormLinear (hidden -> 3 * hidden)."""
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    block = LayerNormLinear(
        config.hidden_size,
        3 * config.hidden_size,
        init_method=init_method_normal(init_std),
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
406
407
408
409


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    """Sanity forward/backward for a square Linear layer (hidden -> hidden)."""
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=scaled_init_method_normal(init_std, config.num_layers),
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
434
435


436
437
438
439
440
441
442
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    """Linear forward/backward on ``bs * max_seqlen_q`` tokens, including zero tokens.

    The output is checked to have shape ``(num_tokens, 4 * hidden_size)``.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.max_seqlen_q

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    use_fp8 = fp8_recipe is not None
    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    # Allocate the input directly on the GPU. `torch.randn(..., requires_grad=True).cuda()`
    # creates a CPU leaf plus a non-leaf GPU copy whose .grad is never populated.
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


471
472
473
474
475
476
477
478
479
480
481
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
    """GroupedLinear forward/backward where exactly one of the ``num_gemms`` splits is empty.

    Each non-empty split has ``16 * bs * max_seqlen_q`` tokens. The split at
    position ``empty_split`` ("first", "last" or "middle") is set to 0, so the
    total token count is ``(num_gemms - 1)`` times the per-split size.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4():
            pytest.skip("NVFP4 not supported for grouped linear")

    use_fp8 = fp8_recipe is not None
    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    # Allocate the input directly on the GPU. `torch.randn(..., requires_grad=True).cuda()`
    # creates a CPU leaf plus a non-leaf GPU copy whose .grad is never populated.
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )
    m_splits = [bs * config.max_seqlen_q] * num_gemms
    if empty_split == "first":
        m_splits[0] = 0
    elif empty_split == "last":
        m_splits[-1] = 0
    elif empty_split == "middle":
        m_splits[num_gemms // 2] = 0

    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        out = te_grouped_linear(inp_hidden_states, m_splits)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


Przemek Tredak's avatar
Przemek Tredak committed
520
521
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
    microbatching,
):
    """Sanity forward/backward for LayerNormMLP (FFN size 4 * hidden) over all activations."""
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    # Only clamped_swiglu takes extra activation parameters.
    activation_params = (
        {"limit": 7.0, "alpha": 1.702} if activation == "clamped_swiglu" else None
    )
    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        activation_params=activation_params,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
565
566
567
568


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
):
    """GPT-style TransformerLayer sanity test (no attention mask; see _test_sanity_e2e).

    Also called directly by test_sanity_gpt_126m.
    """
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
617
618
619
620
621
622


def test_sanity_gpt_126m():
    """Run test_sanity_gpt on the "126m" config with the last available param dtype.

    DelayedScaling FP8 is used when FP8 is available; otherwise the run has no recipe.
    """
    fp8_recipe = (
        recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
        if fp8_available
        else None
    )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
    )
Przemek Tredak's avatar
Przemek Tredak committed
638
639
640
641


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """BERT-style TransformerLayer sanity test with a causal self-attention mask type.

    The layer uses apply_residual_connection_post_layernorm=True and
    output_layernorm=True. Also called directly by test_sanity_bert_126m.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    """Run test_sanity_bert on the "126m" config with the last available param dtype.

    DelayedScaling FP8 is used only when FP8 is available. Otherwise the test
    runs without a recipe instead of being skipped by test_sanity_bert's
    availability check. This matches test_sanity_gpt_126m.
    """
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
694
695
696
697


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """T5-style decoder TransformerLayer sanity test (layer_type="decoder").

    Also called directly by test_sanity_T5_126m.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    """Run test_sanity_T5 on the "126m" config with the last available param dtype.

    DelayedScaling FP8 is used only when FP8 is available. Otherwise the test
    runs without a recipe instead of being skipped by test_sanity_T5's
    availability check. This matches test_sanity_gpt_126m.
    """
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
750
751
752
753


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    """AMP sanity test: an fp32-parameter TransformerLayer run under torch.autocast(dtype)."""
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        # Parameters stay fp32; ``dtype`` is only the autocast dtype.
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
783
784
785
786


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
def test_sanity_drop_path(dtype, fp8_recipe, model):
    """TransformerLayer sanity test with ``drop_path_rate=1.0``; weight grads stay enabled."""
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad=False)
Przemek Tredak's avatar
Przemek Tredak committed
818
819
820
821


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    """End-to-end sanity run of a TransformerLayer built with ``fuse_qkv_params=True``."""
    config = model_configs[model]

    # Quantized runs are skipped for configs that cannot use FP8, and for the
    # NVFP4 recipe combined with FP16 output.
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if fp8_recipe is not None and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    std = 0.023
    layer = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )

    _test_sanity_e2e(layer, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-run a TransformerLayer with fused QKV params and ``fuse_wgrad_accumulation=True``.

    The layer is driven by ``_test_sanity_e2e_gradient_accumulation_fusion``.
    """
    config = model_configs[model]

    # Quantized runs are skipped for configs that cannot use FP8, and for the
    # NVFP4 recipe combined with FP16 output.
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if fp8_recipe is not None and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    std = 0.023
    layer = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(layer, dtype, config, fp8_recipe, skip_wgrad)


def test_model_multiple_cast():
    """Check that a Linear's output dtype follows a ``.half()`` cast of module and input."""
    inp = torch.zeros((16, 16), device="cuda")
    layer = Linear(16, 32)

    # Freshly constructed module with float32 input -> float32 output.
    assert layer(inp).dtype == torch.float32

    # After casting the module (in place) and the input to half -> float16 output.
    layer.half()
    inp = inp.half()
    assert layer(inp).dtype == torch.float16


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
    """Run ``general_gemm`` on N x N operands carved from one buffer at odd element offsets."""
    buf = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)

    # Both operands are N*N-element views of the same buffer; they start ``offset``
    # and ``2 * offset`` elements past its beginning.
    inp = buf[offset:-offset].reshape(N, N)
    weight = buf[2 * offset :].reshape(N, N)

    general_gemm(A=weight, B=inp, workspace=get_workspace())
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    """Run an FP8 ``general_gemm`` whose operands are offset views of one quantized buffer."""
    offset = 16
    buf = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)

    # E4M3 quantizer built from a unit scale and a unit amax.
    quantizer = Float8Quantizer(
        torch.ones(1).cuda().squeeze(),
        torch.ones(1).cuda().squeeze(),
        tex.DType.kFloat8E4M3,
    )
    buf_fp8 = quantizer(buf)

    # Only row 0 of the quantized buffer is used. The two operands overlap and
    # start ``offset`` elements apart.
    inp_fp8 = torch.reshape(buf_fp8[0][:-offset], (N, N))
    weight_fp8 = torch.reshape(buf_fp8[0][offset:], (N, N))

    general_gemm(
        weight_fp8,
        inp_fp8,
        get_workspace(),
        datatype,  # output dtype
        bias=None,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """Test the functionality of replace_raw_data"""
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    tensor = quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
    source = torch.randn(tensor.shape, dtype=torch.bfloat16, device="cuda")
    quantizer.update_quantized(source, tensor)

    # Snapshot the objects held by the non-data attributes before the swap.
    tracked = ("_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid")
    before = {name: getattr(tensor, name) for name in tracked}

    previous = tensor._data
    replacement = torch.empty_like(previous)
    replace_raw_data(tensor, replacement)

    # The tensor must now point at the replacement buffer...
    assert tensor._data.data_ptr() != previous.data_ptr()
    assert tensor._data.data_ptr() == replacement.data_ptr()
    # ...whose contents are exactly those of the old buffer.
    torch.testing.assert_close(previous, tensor._data, atol=0, rtol=0)
    # Every other tracked attribute must still be the very same object.
    for name in tracked:
        assert getattr(tensor, name) is before[name]


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_quantized_model_init_high_precision_init_val():
    """Test quantized_model_init with preserve_high_precision_init_val=True.

    Checks that:
    - the quantized weight exposes a CPU copy of its high-precision initial value;
    - re-quantizing that copy reproduces the weight exactly (compared after dequantization);
    - ``clear_high_precision_init_val()`` drops the copy.
    """
    with quantized_model_init(preserve_high_precision_init_val=True):
        model = Linear(768, 768)

    weight = model.weight

    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
    assert hasattr(
        weight, "clear_high_precision_init_val"
    ), "clear_high_precision_init_val() not found"

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

    # Quantize the preserved value into a fresh tensor with the weight's own quantizer.
    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape, dtype=weight.dtype, device=weight.device
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)

    # The result must match the original weight exactly.
    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )

    weight.clear_high_precision_init_val()
    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work"
    # Check the real attribute name. The previous spelling "._high_precision_init_val"
    # (leading dot) can never be an attribute, so that assertion always passed.
    assert not hasattr(
        weight, "_high_precision_init_val"
    ), "clear_high_precision_init_val() not work"


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules."""

    # An identity torch.autograd.Function, used through its ``apply`` callable
    # rather than wrapped in an nn.Module.
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    # Snapshot and reset the gradient before the second backward. The second pass
    # accumulates into the existing ``inp.grad``. Without this step, both names
    # would refer to the same accumulated tensor and the comparison below would
    # be meaningless.
    grad_checkpoint = inp.grad.clone()
    inp.grad = None

    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
    """Test that backward memory stays small when a Linear's weight is frozen.

    Runs the forward pass under ``autocast()`` with its default recipe (no recipe
    is passed). With ``weight.requires_grad`` disabled, the peak-memory increase
    during backward must stay below 5.5 MB. The expectation is that
    ``grad_output`` only needs a column-wise quantized copy.
    """
    dim = 1024
    linear = Linear(dim, dim, bias=False)
    x = torch.randn(dim, dim, requires_grad=True, device="cuda")

    # Freeze weights
    linear.weight.requires_grad = False

    # Forward pass under autocast (default recipe); backward runs outside the context.
    with autocast():
        o = linear(x)
        g_o = torch.randn_like(o)

    # NOTE(review): peak-memory stats are not reset before this measurement. If an
    # earlier allocation in the process (e.g. the forward pass or a previous test)
    # set a higher peak, the difference below can be 0 and the check passes
    # vacuously. Consider torch.cuda.reset_peak_memory_stats() here.
    max_memory_before_backward = torch.cuda.max_memory_allocated()
    o.backward(g_o)
    max_memory_after_backward = torch.cuda.max_memory_allocated()

    memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
    assert memory_diff < 5.5, (
        f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
        " grad_output should be quantized only columnwise."
    )


1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights"""
1081
1082
    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106

    # Tensor dimensions
    sequence_length = 32
    hidden_size = 32

    # Skip invalid configurations
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Construct quantization recipe
    with_quantization = quantization not in (None, "None")
    quantization_recipe = None
    if quantization == "fp8_delayed_scaling":
        quantization_recipe = recipe.DelayedScaling()
    elif quantization == "fp8_current_scaling":
        quantization_recipe = recipe.Float8CurrentScaling()
    elif quantization == "mxfp8":
        quantization_recipe = recipe.MXFP8BlockScaling()

    # Construct module
    module = None
    with torch.no_grad():
1107
        with quantized_model_init(enabled=with_quantization, recipe=quantization_recipe):
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
            if module_name == "Linear":
                module = Linear(hidden_size, hidden_size)
            elif module_name == "LayerNormLinear":
                module = LayerNormLinear(hidden_size, hidden_size)
            elif module_name == "LayerNormMLP":
                module = LayerNormMLP(hidden_size, hidden_size)
            elif module_name == "GroupedLinear":
                module = GroupedLinear(1, hidden_size, hidden_size)
            elif module_name == "ops.Linear":
                module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)

    def check_weights():
        """Helper function to check that weight parameters have expected data"""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Check that modules have expected weights after initialization
    check_weights()

    # Check that modules have expected weights after forward pass
    with torch.inference_mode():
        x = torch.zeros(sequence_length, hidden_size, device="cuda")
        kwargs = {}
        if module_name == "GroupedLinear":
            kwargs["m_splits"] = [sequence_length]
1142
        with autocast(enabled=with_quantization, recipe=quantization_recipe):
1143
1144
            y = module(x, **kwargs)
    check_weights()