test_sanity.py 38 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
from typing import Optional

Przemek Tredak's avatar
Przemek Tredak committed
7
8
import torch
import pytest
9
import os
Przemek Tredak's avatar
Przemek Tredak committed
10

11
12
13
import transformer_engine
import transformer_engine.pytorch as te
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
Przemek Tredak's avatar
Przemek Tredak committed
14
15
16
17
18
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch import (
19
20
    autocast,
    quantized_model_init,
Przemek Tredak's avatar
Przemek Tredak committed
21
22
    LayerNormLinear,
    Linear,
23
    GroupedLinear,
Przemek Tredak's avatar
Przemek Tredak committed
24
25
    LayerNormMLP,
    TransformerLayer,
26
27
    RMSNorm,
    LayerNorm,
28
29
30
31
32
33
34
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
    Float8Tensor,
    MXFP8Tensor,
    checkpoint,
    QuantizedTensor,
    is_bf16_available,
Przemek Tredak's avatar
Przemek Tredak committed
35
36
)
from transformer_engine.common import recipe
37
import transformer_engine_torch as tex
38
from transformer_engine.pytorch.cpp_extensions import general_gemm
39
from transformer_engine.pytorch.tensor.utils import replace_raw_data
40
from utils import ModelConfig
Przemek Tredak's avatar
Przemek Tredak committed
41

42
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

# Record initial RNG state from script run.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Set NVTE_TEST_NVINSPECT_ENABLED=1 to rerun the sanity tests with the
# nvdlfw_inspect debug API initialized.
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))

if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests should behave the same with debug mode on. The config
    # file is expected to enable a dummy feature so that debug mode is not
    # switched off, which can happen when no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )

67

68
69
70
71
72
73
74
75
76
def is_fp8_supported(config: ModelConfig):
    """Return True when ``config``'s sizes are all multiples of 16.

    The checked quantities are ``max_seqlen_q * batch_size``,
    ``max_seqlen_kv * batch_size``, ``hidden_size`` and ``hidden_size_kv``.
    Tests use this to skip FP8 recipes for configs that fail the check.
    """
    q_tokens = config.max_seqlen_q * config.batch_size
    kv_tokens = config.max_seqlen_kv * config.batch_size
    if q_tokens % 16 != 0 or kv_tokens % 16 != 0:
        return False
    return config.hidden_size % 16 == 0 and config.hidden_size_kv % 16 == 0
Przemek Tredak's avatar
Przemek Tredak committed
77

78

Przemek Tredak's avatar
Przemek Tredak committed
79
# Named model configurations shared by the sanity tests. Positional arguments
# are forwarded unchanged to ``utils.ModelConfig``. The odd sizes in "weird"
# are presumably meant to exercise non-aligned shapes (TODO confirm).
model_configs = {
    "126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
    "small": ModelConfig(2, 32, 2, 32, num_layers=2),
    "weird": ModelConfig(3, 37, 3, 23, num_layers=2),
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}

86
87
88
89
90
91
92
93
94

def nvfp4_vanilla():
    """Build an ``NVFP4BlockScaling`` recipe with default quantizer params.

    The forward-input, forward-weight and backward-gradient quantization
    settings are each replaced by a fresh, default-constructed
    ``recipe.QParams()``.
    """
    vanilla = recipe.NVFP4BlockScaling()
    for field_name in ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad"):
        setattr(vanilla, field_name, recipe.QParams())
    return vanilla


95
96
97
# Quantization recipes the tests are parametrized over, limited to those the
# current device supports. ``None`` stands for a run without quantization.
fp8_recipes = []
if mxfp8_available:
    # TODO: fix check for this -- NVFP4 is gated on MXFP8 availability for now.
    fp8_recipes += [recipe.MXFP8BlockScaling(), nvfp4_vanilla()]
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes += [recipe.Float8CurrentScaling(), recipe.DelayedScaling()]
fp8_recipes.append(None)
Przemek Tredak's avatar
Przemek Tredak committed
105

106
# Parameter dtypes to sweep; bf16 requires sm_80 or higher.
param_types = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_available() else []
)

all_boolean = [True, False]
batch_sizes_with_zero = list(range(3))
Przemek Tredak's avatar
Przemek Tredak committed
112

113
114
115
116
117
118
119
120
121
122
123
# Activation names swept by the LayerNormMLP sanity test.
all_activations = [
    "gelu", "geglu",
    "qgelu", "qgeglu",
    "relu", "reglu",
    "srelu", "sreglu",
    "silu", "swiglu",
    "clamped_swiglu",
]
# Values passed as the ``normalization`` argument of the tested modules.
all_normalizations = ["LayerNorm", "RMSNorm"]
schetlur-nv's avatar
schetlur-nv committed
127

128

schetlur-nv's avatar
schetlur-nv committed
129
130
131
132
133
def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


134
135
136
137
138
139
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture: clear ``FP8GlobalStateManager`` state after each test.

    This stops FP8 state from one test leaking into the next.
    """
    yield None
    FP8GlobalStateManager.reset()


140
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run ``block`` under ``torch.autocast(dtype)`` plus TE autocast and check dtypes.

    A float32 input with an all-random boolean attention mask is fed through
    ``block``. The output must have ``dtype``. The input gradient and the
    gradient of every parameter that still requires grad must be float32.
    If ``skip_wgrad`` is set, all parameters are frozen first.
    """
    seq_q, seq_kv = config.max_seqlen_q, config.max_seqlen_kv
    hidden_states = torch.randn(
        (seq_q, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    hidden_states.retain_grad()
    attn_mask = torch.randint(2, (1, 1, seq_q, seq_kv), dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
            out = block(hidden_states, attention_mask=attn_mask)
        loss = out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert out.dtype == dtype, "AMP wrong output type."
    assert hidden_states.grad is not None, "Gradient should not be empty"
    assert hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    trainable = ((n, prm) for n, prm in block.named_parameters() if prm.requires_grad)
    for name, param in trainable:
        assert param.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


175
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    """Check that trainable weights of ``block`` accumulate into ``main_grad``.

    Every parameter whose name contains "weight" (but not "layer_norm_weight")
    and that still requires grad gets a zero ``main_grad`` buffer. After one
    forward/backward pass, each of those buffers must contain a non-zero value.
    Presumably ``block`` was built with weight-gradient accumulation fusion
    enabled (TODO confirm at the call site).
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    def _tracked(name, param):
        # LayerNorm weights are deliberately excluded from the main_grad check.
        return "layer_norm_weight" not in name and "weight" in name and param.requires_grad

    for name, param in block.named_parameters():
        if _tracked(name, param):
            param.main_grad = torch.zeros_like(param)

    with autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        out = block(hidden_states, attention_mask=attn_mask)
    loss = out.sum()
    loss.backward()
    torch.cuda.synchronize()

    failed_grads = [
        name
        for name, param in block.named_parameters()
        if _tracked(name, param) and torch.count_nonzero(param.main_grad) == 0
    ]
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
213

Przemek Tredak's avatar
Przemek Tredak committed
214

215
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass through ``block`` with no attention mask.

    The input has shape ``(max_seqlen_q, batch_size, hidden_size)`` and dtype
    ``dtype``. The forward runs inside TE ``autocast``, enabled only when a
    recipe is given. If ``skip_wgrad`` is set, all parameters are frozen first.
    """
    shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    hidden_states = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)

    if skip_wgrad:
        _disable_wgrads(block)

    with autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        out = block(hidden_states)
    out.sum().backward()
    torch.cuda.synchronize()


234
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass with a per-batch boolean attention mask.

    The mask has shape ``(batch_size, 1, 1, max_seqlen_q)`` with random
    entries. If ``skip_wgrad`` is set, all parameters are frozen first.
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    mask_shape = (config.batch_size, 1, 1, config.max_seqlen_q)
    attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    with autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        out = block(hidden_states, attention_mask=attn_mask)
    out.sum().backward()
    torch.cuda.synchronize()


260
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass through a decoder-style ``block``.

    The same random hidden states are used as both the decoder input and the
    ``encoder_output``. Self-attention gets a ``(1, 1, seq_q, seq_kv)`` random
    mask and encoder-decoder attention gets a ``(batch, 1, 1, seq_kv)`` random
    mask. If ``skip_wgrad`` is set, all parameters are frozen first.
    """
    seq_q, seq_kv, batch = config.max_seqlen_q, config.max_seqlen_kv, config.batch_size
    hidden_states = torch.randn(
        (seq_q, batch, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    self_mask = torch.randint(2, (1, 1, seq_q, seq_kv), dtype=torch.bool, device="cuda")
    cross_mask = torch.randint(2, (batch, 1, 1, seq_kv), dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    with autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        out = block(
            hidden_states,
            attention_mask=self_mask,
            encoder_output=hidden_states,
            enc_dec_attn_mask=cross_mask,
        )
    out.sum().backward()
    torch.cuda.synchronize()


297
298
299
def _test_sanity_common(
    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
    """Forward/backward smoke test for a single TE module.

    Args:
        block: module under test.
        dtype: dtype of the random input.
        config: supplies ``max_seqlen_q``, ``batch_size`` and ``hidden_size``.
        fp8_recipe: TE recipe, or None to run without quantization.
        skip_wgrad: freeze all parameters of ``block``.
        skip_dgrad: create the input with ``requires_grad=False``.
        microbatching: call ``block`` twice, first with ``is_first_microbatch=True``
            and then with ``False``; only the second output is backpropagated.

    Skips the test when both ``skip_wgrad`` and ``skip_dgrad`` are set.
    If ``block`` returns a tuple, its first element is used as the output.
    """
    if skip_wgrad and skip_dgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        if microbatching:
            _ = block(inp, is_first_microbatch=True)
            out = block(inp, is_first_microbatch=False)
        else:
            out = block(inp)

    primary = out[0] if isinstance(out, tuple) else out
    primary.sum().backward()
    torch.cuda.synchronize()


327
def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    """Run a standalone normalization ``block`` under torch AMP and check dtypes.

    ``skip_wgrad`` freezes the block's parameters and ``skip_dgrad`` creates the
    input with ``requires_grad=False``, the same as ``_test_sanity_common``.
    When both are set nothing would be differentiated, so the test is skipped.

    Checks:
        * the output dtype equals the autocast ``dtype``;
        * when dgrad is computed, the input gradient exists and is float32;
        * every parameter that still requires grad has a float32 gradient.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    # The input is created directly on the device, so it is a leaf tensor and
    # its .grad is populated without retain_grad(). (retain_grad() would also
    # raise when requires_grad is False.)
    te_inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    # Previously skip_wgrad was ignored, so the wgrad-only and dgrad-only
    # parametrizations exercised the same code path.
    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    if not skip_dgrad:
        assert te_inp.grad is not None, "Gradient should not be empty"
        assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    """AMP sanity test for a standalone LayerNorm/RMSNorm with float32 parameters."""
    config = model_configs[model]
    norm_cls = {"RMSNorm": RMSNorm}.get(normalization, LayerNorm)
    norm = norm_cls(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(norm, dtype, config, skip_wgrad, skip_dgrad)
364
365


Przemek Tredak's avatar
Przemek Tredak committed
366
367
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
    microbatching,
):
    """Smoke test for LayerNormLinear projecting hidden_size -> 3 * hidden_size."""
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    block = LayerNormLinear(
        config.hidden_size,
        3 * config.hidden_size,
        init_method=init_method_normal(init_std),
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
405
406
407
408


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    """Smoke test for a square Linear layer (hidden_size -> hidden_size)."""
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=scaled_init_method_normal(init_std, config.num_layers),
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
433
434


435
436
437
438
439
440
441
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    """Linear must run forward and backward when the token count may be zero.

    ``num_tokens = bs * max_seqlen_q``, so ``bs == 0`` feeds an empty
    ``(0, hidden_size)`` input. The output shape must be
    ``(num_tokens, 4 * hidden_size)``.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.max_seqlen_q

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    use_fp8 = fp8_recipe is not None
    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    # Allocate on the GPU directly. Calling `.cuda()` on a CPU tensor created
    # with requires_grad=True returns a non-leaf copy, and backward would then
    # send the input gradient to a throwaway host tensor.
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


470
471
472
473
474
475
476
477
478
479
480
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
    """GroupedLinear must handle one empty split (first, last or middle GEMM).

    Each of the ``num_gemms`` splits gets ``bs * 16 * max_seqlen_q`` tokens,
    except the one selected by ``empty_split``, which gets 0. The input
    therefore holds ``(num_gemms - 1)`` splits' worth of tokens.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4():
            pytest.skip("NVFP4 not supported for grouped linear")

    use_fp8 = fp8_recipe is not None
    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    # Allocate on the GPU directly so the input is a leaf tensor. `.cuda()` on a
    # CPU tensor with requires_grad=True would return a non-leaf copy.
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )
    m_splits = [bs * config.max_seqlen_q] * num_gemms
    if empty_split == "first":
        m_splits[0] = 0
    elif empty_split == "last":
        m_splits[-1] = 0
    elif empty_split == "middle":
        m_splits[num_gemms // 2] = 0

    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        out = te_grouped_linear(inp_hidden_states, m_splits)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


Przemek Tredak's avatar
Przemek Tredak committed
519
520
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
@pytest.mark.parametrize("checkpoint", all_boolean)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
    microbatching,
    checkpoint,
):
    """Smoke test for LayerNormMLP (FFN width 4 * hidden_size) over all activations.

    Note: the ``checkpoint`` parameter here is a bool and shadows the imported
    ``checkpoint`` function inside this test.
    """
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    # Only clamped_swiglu takes extra activation parameters.
    if activation == "clamped_swiglu":
        activation_params = {"limit": 7.0, "alpha": 1.702}
    else:
        activation_params = None

    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        activation_params=activation_params,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
        checkpoint=checkpoint,
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
567
568
569
570


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
):
    """End-to-end smoke test of a pre-LayerNorm (GPT-style) TransformerLayer."""
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    layer = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(layer, dtype, config, fp8_recipe, skip_wgrad)
619
620
621
622
623
624


def test_sanity_gpt_126m():
    """Run the GPT sanity test once on the "126m" config.

    Uses a DelayedScaling E4M3 recipe when FP8 is available, otherwise runs
    without FP8. The dtype is the last entry of ``param_types``.
    """
    fp8_recipe = (
        recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
        if fp8_available
        else None
    )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
    )
Przemek Tredak's avatar
Przemek Tredak committed
640
641
642
643


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """End-to-end smoke test of a post-LayerNorm (BERT-style) TransformerLayer."""
    config = model_configs[model]

    if fp8_recipe is not None:
        # Also guards direct callers (e.g. test_sanity_bert_126m) that pass a recipe.
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    layer = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_bert(layer, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    """Run the BERT sanity test once on the "126m" config.

    Uses a DelayedScaling E4M3 recipe (amax history length 1) when FP8 is
    available, otherwise runs without FP8. The dtype is the last entry of
    ``param_types``.
    """
    # Fall back to a high-precision run on devices without FP8, matching
    # test_sanity_gpt_126m. Always passing a recipe made test_sanity_bert skip
    # the whole test there.
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
696
697
698
699


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """End-to-end smoke test of a decoder TransformerLayer with cross-attention."""
    config = model_configs[model]

    if fp8_recipe is not None:
        # Also guards direct callers (e.g. test_sanity_T5_126m) that pass a recipe.
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    layer = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_T5(layer, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    """Run the T5 decoder sanity test once on the "126m" config.

    Uses a DelayedScaling E4M3 recipe (amax history length 1) when FP8 is
    available, otherwise runs without FP8. The dtype is the last entry of
    ``param_types``.
    """
    # Fall back to a high-precision run on devices without FP8, matching
    # test_sanity_gpt_126m. Always passing a recipe made test_sanity_T5 skip
    # the whole test there.
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
752
753
754
755


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    """AMP smoke test: float32 TransformerLayer params, autocast to ``dtype``."""
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    # Parameters stay float32; ``dtype`` is only the torch.autocast target.
    layer = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(layer, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
785
786
787
788


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
def test_sanity_drop_path(dtype, fp8_recipe, model):
    """Smoke test of a TransformerLayer with ``drop_path_rate=1.0``."""
    config = model_configs[model]

    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    layer = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )

    _test_sanity_e2e(layer, dtype, config, fp8_recipe, False)
Przemek Tredak's avatar
Przemek Tredak committed
820
821
822
823


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    """End-to-end sanity check of a TransformerLayer built with ``fuse_qkv_params=True``."""
    config = model_configs[model]

    # Quantized runs are skipped when the config or the recipe/dtype pair is unsupported.
    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    layer = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )

    _test_sanity_e2e(layer, dtype, config, fp8_recipe, skip_wgrad)
856
857
858
859


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity check of a TransformerLayer with fused QKV params and fused wgrad accumulation."""
    config = model_configs[model]

    # Quantized runs are skipped when the config or the recipe/dtype pair is unsupported.
    quantized = fp8_recipe is not None
    if quantized and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    layer = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(layer, dtype, config, fp8_recipe, skip_wgrad)
893
894


895
def test_model_multiple_cast():
    """Linear output dtype should follow the module after it is cast to half precision."""
    inp = torch.zeros((16, 16), device="cuda")
    layer = Linear(16, 32)

    # Default construction: FP32 in, FP32 out.
    out_fp32 = layer(inp)
    assert out_fp32.dtype == torch.float32

    # After .half() on both module and input, the output must be FP16.
    layer.half()
    inp = inp.half()
    out_fp16 = layer(inp)
    assert out_fp16.dtype == torch.float16
907
908
909
910
911
912


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
    """Run a GEMM on operands that are offset views into one shared 1-D buffer."""
    numel = N * N
    buffer = torch.randn(numel + 2 * offset, device="cuda", dtype=datatype)

    # Each operand is an N*N-element window: `inp` starts `offset` elements in,
    # `weight` starts `2 * offset` elements in, so their data pointers are shifted.
    inp = buffer[offset : offset + numel].reshape(N, N)
    weight = buffer[2 * offset : 2 * offset + numel].reshape(N, N)

    general_gemm(A=weight, B=inp)
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    """Run an FP8 GEMM on operands sliced at an element offset from one quantized buffer."""
    offset = 16
    buffer = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)

    # Unit scale and amax on the GPU; E4M3 storage.
    quantizer = Float8Quantizer(
        torch.ones(1).cuda().squeeze(),
        torch.ones(1).cuda().squeeze(),
        tex.DType.kFloat8E4M3,
    )

    # Two N*N windows taken from the first row of the quantized buffer,
    # shifted relative to each other by `offset` elements.
    buffer_fp8 = quantizer(buffer)
    inp_fp8 = torch.reshape(buffer_fp8[0][:-offset], (N, N))
    weight_fp8 = torch.reshape(buffer_fp8[0][offset:], (N, N))

    general_gemm(
        weight_fp8,
        inp_fp8,
        datatype,
        bias=None,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()
946
947


948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """replace_raw_data must swap in the new buffer, keep values, and leave metadata intact."""
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    tensor = quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
    source = torch.randn(tensor.shape, dtype=torch.bfloat16, device="cuda")
    quantizer.update_quantized(source, tensor)

    # Snapshot the exact objects that replace_raw_data is expected not to touch.
    preserved = {
        name: getattr(tensor, name)
        for name in ("_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid")
    }

    previous = tensor._data
    replacement = torch.empty_like(previous)
    replace_raw_data(tensor, replacement)

    # The tensor now points at the replacement storage...
    assert tensor._data.data_ptr() != previous.data_ptr()
    assert tensor._data.data_ptr() == replacement.data_ptr()
    # ...which holds exactly the same bytes as before.
    torch.testing.assert_close(previous, tensor._data, atol=0, rtol=0)
    # Every snapshotted attribute is still the very same object.
    for name, obj in preserved.items():
        assert getattr(tensor, name) is obj


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_quantized_model_init_high_precision_init_val():
    """Test quantized_model_init with preserve_high_precision_init_val=True

    Checks that:
    - the quantized weight keeps a copy of its high-precision init value on the CPU;
    - re-quantizing that copy reproduces the stored quantized weight exactly;
    - ``clear_high_precision_init_val`` discards the copy.
    """
    with quantized_model_init(preserve_high_precision_init_val=True):
        model = Linear(768, 768)

    weight = model.weight

    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
    assert hasattr(
        weight, "clear_high_precision_init_val"
    ), "clear_high_precision_init_val() not found"

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

    # Re-quantize the preserved value with the weight's own quantizer; it must
    # match the stored quantized weight bit-for-bit.
    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape, dtype=weight.dtype, device=weight.device
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)

    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )

    weight.clear_high_precision_init_val()
    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work"
    # Check the real attribute name (no leading dot) so this assertion can
    # actually fail if the attribute survives clearing.
    assert not hasattr(
        weight, "_high_precision_init_val"
    ), "clear_high_precision_init_val() not work"
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules."""

    # A torch.autograd.Function: its `apply` is a plain callable, not an nn.Module.
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    # Snapshot and reset the gradient: otherwise the second backward accumulates
    # into the same `.grad` tensor and the comparison below is against itself.
    grad_checkpoint = inp.grad.clone()
    inp.grad = None

    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)
1041
1042


1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
    """Test that backward memory stays small when Linear weights are frozen.

    Runs under ``autocast()`` with its default recipe (which recipe that is may
    depend on the device — NOTE(review): confirm; the original docstring said MXFP8).
    The peak-memory increase during backward must stay below 5.5 MB.
    """
    dim = 1024
    linear = Linear(dim, dim, bias=False)
    x = torch.randn(dim, dim, requires_grad=True, device="cuda")

    # Freeze weights
    linear.weight.requires_grad = False

    # max_memory_allocated() reports the peak since the last reset. Without a
    # reset, a larger peak from an earlier test would make both readings equal
    # and the assertion below would pass vacuously.
    torch.cuda.reset_peak_memory_stats()

    # Forward and backward pass with FP8
    with autocast():
        o = linear(x)
        g_o = torch.randn_like(o)

    max_memory_before_backward = torch.cuda.max_memory_allocated()
    o.backward(g_o)
    max_memory_after_backward = torch.cuda.max_memory_allocated()

    memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
    assert memory_diff < 5.5, (
        f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
        " grad_output should be quantized only columnwise."
    )


1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Weights built under no_grad should hold only forward-pass data.

    FP8 weights must have no transpose cache, and MXFP8 weights no column-wise
    data. This is checked right after quantized initialization and again after
    a forward pass under ``torch.inference_mode``.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    seq_len = 32
    hidden = 32

    # Skip configurations the current device cannot run.
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Map the parametrized quantization name to a recipe factory (None -> no recipe).
    recipe_factories = {
        "fp8_delayed_scaling": recipe.DelayedScaling,
        "fp8_current_scaling": recipe.Float8CurrentScaling,
        "mxfp8": recipe.MXFP8BlockScaling,
    }
    with_quantization = quantization not in (None, "None")
    make_recipe = recipe_factories.get(quantization)
    quantization_recipe = None if make_recipe is None else make_recipe()

    # Map the parametrized module name to a constructor.
    module_factories = {
        "Linear": lambda: Linear(hidden, hidden),
        "LayerNormLinear": lambda: LayerNormLinear(hidden, hidden),
        "LayerNormMLP": lambda: LayerNormMLP(hidden, hidden),
        "GroupedLinear": lambda: GroupedLinear(1, hidden, hidden),
        "ops.Linear": lambda: transformer_engine.pytorch.ops.Linear(hidden, hidden),
    }
    make_module = module_factories.get(module_name)
    module = None
    with torch.no_grad():
        with quantized_model_init(enabled=with_quantization, recipe=quantization_recipe):
            if make_module is not None:
                module = make_module()

    def check_weights():
        """Assert each quantized parameter carries only its row-wise/forward data."""
        for p in module.parameters():
            if isinstance(p, Float8Tensor):
                assert p._data is not None, "Missing FP8 data"
                assert (
                    p._transpose is None and p._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(p, MXFP8Tensor):
                assert p._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    p._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Right after initialization.
    check_weights()

    # After an inference-mode forward pass.
    with torch.inference_mode():
        x = torch.zeros(seq_len, hidden, device="cuda")
        kwargs = {"m_splits": [seq_len]} if module_name == "GroupedLinear" else {}
        with autocast(enabled=with_quantization, recipe=quantization_recipe):
            module(x, **kwargs)
    check_weights()