test_sanity.py 39.4 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
6
from typing import Optional

Przemek Tredak's avatar
Przemek Tredak committed
7
8
import torch
import pytest
9
import os
yuguo's avatar
yuguo committed
10
from torch.utils.cpp_extension import IS_HIP_EXTENSION
Przemek Tredak's avatar
Przemek Tredak committed
11

12
import transformer_engine.pytorch
13
14
15
16
17
from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    fp8_model_init,
)
Przemek Tredak's avatar
Przemek Tredak committed
18
19
20
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
21
    is_bf16_compatible,
Przemek Tredak's avatar
Przemek Tredak committed
22
23
24
25
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
26
    GroupedLinear,
Przemek Tredak's avatar
Przemek Tredak committed
27
28
    LayerNormMLP,
    TransformerLayer,
29
30
    RMSNorm,
    LayerNorm,
Przemek Tredak's avatar
Przemek Tredak committed
31
32
)
from transformer_engine.common import recipe
33
import transformer_engine_torch as tex
34
from transformer_engine.pytorch.cpp_extensions import general_gemm
35
from transformer_engine.pytorch.module.base import get_workspace
36
37
38
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
39
40
    Float8Quantizer,
    Float8Tensor,
41
)
42
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
43
from transformer_engine.pytorch.tensor.utils import replace_raw_data
44
from transformer_engine.pytorch.distributed import checkpoint
45
from utils import ModelConfig
Przemek Tredak's avatar
Przemek Tredak committed
46

47
# Only run FP8 tests on supported devices: each probe below yields a flag and
# the reason string that the tests hand to ``pytest.skip`` when it is False.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
(
    fp8_block_scaling_available,
    reason_for_no_fp8_block_scaling,
) = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Seed the CPU and CUDA generators once, when the module is imported.
seed = 1234
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)
del _seed_fn

# Any non-zero integer enables the nvdlfw_inspect debug API for this run.
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))

if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests should behave the same with debug=True. A dummy feature
    # is configured (presumably through the config file below) so that debug
    # mode is not switched off, which can happen when no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )

72

73
74
75
76
77
78
79
80
81
def is_fp8_supported(config: ModelConfig):
    """Return whether ``config`` has shapes suitable for the FP8 tests.

    Both flattened token counts (``max_seqlen_q * batch_size`` and
    ``max_seqlen_kv * batch_size``) and both hidden sizes (``hidden_size`` and
    ``hidden_size_kv``) must be multiples of 16.
    """
    alignment = 16
    if (config.max_seqlen_q * config.batch_size) % alignment != 0:
        return False
    if (config.max_seqlen_kv * config.batch_size) % alignment != 0:
        return False
    return config.hidden_size % alignment == 0 and config.hidden_size_kv % alignment == 0
Przemek Tredak's avatar
Przemek Tredak committed
82

83

84

Przemek Tredak's avatar
Przemek Tredak committed
85
# Named model configurations selected by the ``model`` parameter of the tests.
# NOTE(review): the positional arguments presumably follow utils.ModelConfig's
# (batch_size, max_seqlen_q, num_heads, head_dim_qk) order -- confirm.
model_configs = {
    # Used by the *_126m tests.
    "126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
    "small": ModelConfig(2, 32, 2, 32, num_layers=2),
    # Odd sizes, presumably to exercise unaligned shapes.
    "weird": ModelConfig(3, 37, 3, 23, num_layers=2),
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}

92
93
94
95
96
97
98
99
100
# FP8 recipes to parametrize over, limited to those the device supports.
# ``None`` is always last so every test also runs without FP8.
fp8_recipes = [
    *([recipe.MXFP8BlockScaling()] if mxfp8_available else []),
    *([recipe.Float8BlockScaling()] if fp8_block_scaling_available else []),
    *([recipe.Float8CurrentScaling(), recipe.DelayedScaling()] if fp8_available else []),
    None,
]

# Parameter dtypes; bf16 requires sm_80 or higher, so it is appended only when supported.
param_types = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_compatible() else []
)

all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
schetlur-nv's avatar
schetlur-nv committed
111

112

schetlur-nv's avatar
schetlur-nv committed
113
114
115
116
117
def _disable_wgrads(block):
    """Mark every parameter of ``block`` as not requiring a gradient."""
    params = list(block.parameters())
    for param in params:
        param.requires_grad = False


118
119
120
121
122
123
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture: call ``FP8GlobalStateManager.reset()`` after each test.

    Nothing runs before the test body; the reset happens during teardown so
    global FP8 state set by one test does not carry over into the next.
    """
    yield
    FP8GlobalStateManager.reset()


124
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one masked forward/backward pass under ``torch.autocast`` and check dtypes.

    The input is float32. The forward pass runs with CUDA autocast to ``dtype``
    and, when ``fp8_recipe`` is given, inside ``fp8_autocast``. Afterwards the
    output must be ``dtype``, while the input gradient and every trainable
    parameter's gradient must be float32.
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    hidden_states.retain_grad()

    # Random boolean mask of shape (1, 1, max_seqlen_q, max_seqlen_kv).
    attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
            out = block(hidden_states, attention_mask=attn_mask)
        # The loss is still computed under torch.autocast, outside fp8_autocast.
        loss = out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert out.dtype == dtype, "AMP wrong output type."
    dgrad = hidden_states.grad
    assert dgrad is not None, "Gradient should not be empty"
    assert dgrad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, param in block.named_parameters():
        if not param.requires_grad:
            continue
        assert param.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


159
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    """Check that a masked forward/backward pass writes into each weight's ``main_grad``.

    Every trainable parameter whose name contains ``"weight"`` (but not
    ``"layer_norm_weight"``) gets a zero ``main_grad`` buffer before the pass.
    Afterwards each of those buffers must hold at least one non-zero value.
    NOTE(review): presumably the caller builds ``block`` with fused weight-gradient
    accumulation enabled -- this helper does not configure that itself.
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    def tracks_main_grad(name, param):
        # Layer-norm weights are excluded; any other trainable "weight" is tracked.
        if "layer_norm_weight" in name:
            return False
        return "weight" in name and param.requires_grad

    for name, param in block.named_parameters():
        if tracks_main_grad(name, param):
            param.main_grad = torch.zeros_like(param)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        out = block(hidden_states, attention_mask=attn_mask)
    loss = out.sum()
    loss.backward()
    torch.cuda.synchronize()

    failed_grads = [
        name
        for name, param in block.named_parameters()
        if tracks_main_grad(name, param) and not torch.count_nonzero(param.main_grad) > 0
    ]
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
197

Przemek Tredak's avatar
Przemek Tredak committed
198

199
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test one unmasked forward/backward pass of ``block``.

    The input has shape (max_seqlen_q, batch_size, hidden_size) and dtype
    ``dtype``. FP8 is enabled when ``fp8_recipe`` is not None, and parameters are
    frozen when ``skip_wgrad`` is set. Only checks that the pass runs.
    """
    input_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    hidden_states = torch.randn(input_shape, dtype=dtype, device="cuda", requires_grad=True)

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        out = block(hidden_states)
    out.sum().backward()
    torch.cuda.synchronize()


218
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test one forward/backward pass with a per-sample attention mask.

    The mask is a random boolean tensor of shape
    (batch_size, 1, 1, max_seqlen_q). Only checks that the pass runs.
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    mask_shape = (config.batch_size, 1, 1, config.max_seqlen_q)
    attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        out = block(hidden_states, attention_mask=attn_mask)
    out.sum().backward()
    torch.cuda.synchronize()


244
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test one decoder-style forward/backward pass with cross-attention.

    The same random input serves as both decoder input and ``encoder_output``.
    Two random boolean masks are used: a self-attention mask of shape
    (1, 1, max_seqlen_q, max_seqlen_kv) and an encoder-decoder mask of shape
    (batch_size, 1, 1, max_seqlen_kv). Only checks that the pass runs.
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    self_attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )
    cross_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    forward_kwargs = {
        "attention_mask": self_attn_mask,
        "encoder_output": hidden_states,
        "enc_dec_attn_mask": cross_attn_mask,
    }
    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        out = block(hidden_states, **forward_kwargs)
    out.sum().backward()
    torch.cuda.synchronize()


281
282
283
def _test_sanity_common(
    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
    """Smoke-test one forward/backward pass of a single TE module.

    ``skip_dgrad`` turns off the input gradient and ``skip_wgrad`` freezes all
    parameters; the test is skipped when both are set. With ``microbatching``,
    ``block`` is called twice (``is_first_microbatch=True`` then ``False``) and
    only the second output is backpropagated.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        if microbatching:
            _ = block(inp, is_first_microbatch=True)
            out = block(inp, is_first_microbatch=False)
        else:
            out = block(inp)

    # If the module returns a tuple, only its first element feeds the loss.
    if isinstance(out, tuple):
        out = out[0]
    out.sum().backward()
    torch.cuda.synchronize()


311
def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    """Run a normalization module under ``torch.autocast`` and check AMP dtypes.

    The float32 input requires a gradient unless ``skip_dgrad`` is set. The
    block's parameters are frozen when ``skip_wgrad`` is set, matching
    ``_test_sanity_common``. The test is skipped when both are set. After the
    pass the output must be ``dtype``. The input gradient (when computed) and
    every trainable parameter's gradient must be float32.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    # The input is a leaf tensor, so its .grad is populated without retain_grad().
    # retain_grad() would also raise when requires_grad=False.
    te_inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    if not skip_dgrad:
        assert te_inp.grad is not None, "Gradient should not be empty"
        assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    """AMP sanity check for a standalone float32 ``RMSNorm`` or ``LayerNorm`` module."""
    config = model_configs[model]
    if normalization == "RMSNorm":
        norm_cls = RMSNorm
    else:
        norm_cls = LayerNorm

    norm = norm_cls(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(norm, dtype, config, skip_wgrad, skip_dgrad)
348
349


Przemek Tredak's avatar
Przemek Tredak committed
350
351
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
    microbatching,
):
    """Sanity-check ``LayerNormLinear`` (hidden -> 3 * hidden) across precisions and options."""
    config = model_configs[model]

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)

    module_kwargs = dict(
        init_method=init_method_normal(0.023),
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    ln_linear = LayerNormLinear(config.hidden_size, 3 * config.hidden_size, **module_kwargs)
    _test_sanity_common(
        ln_linear, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching
    )
Przemek Tredak's avatar
Przemek Tredak committed
393
394
395
396


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    """Sanity-check a square ``Linear`` layer (hidden -> hidden).

    FP8 variants are skipped when the device lacks support for the recipe or
    when the model config fails ``is_fp8_supported``.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        # Use the module-level helper, consistent with the other tests in this file.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=output_layer_init_method,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
425
426


427
428
429
430
431
432
433
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    """Check that ``Linear`` handles token counts of ``bs * max_seqlen_q``, including zero.

    The output must have shape ``(num_tokens, 4 * hidden_size)``, and the
    backward pass must run.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.max_seqlen_q

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        # Use the module-level helper, consistent with the other tests in this file.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    # Allocate on the GPU directly. Calling .cuda() after creation would leave a
    # CPU leaf tensor plus a non-leaf device copy.
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


466
467
468
469
470
471
472
473
474
475
476
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
    """Check ``GroupedLinear`` when one of its ``m_splits`` is empty.

    Each of the ``num_gemms`` splits gets ``bs * 16 * max_seqlen_q`` tokens,
    except one (first, last or middle) that is set to zero. The input therefore
    holds ``(num_gemms - 1)`` splits' worth of tokens.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        # Use the module-level helper, consistent with the other tests in this file.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    # Allocate on the GPU directly so the input is a leaf tensor on the device.
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )
    m_splits = [bs * config.max_seqlen_q] * num_gemms
    if empty_split == "first":
        m_splits[0] = 0
    elif empty_split == "last":
        m_splits[-1] = 0
    elif empty_split == "middle":
        m_splits[num_gemms // 2] = 0

    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_grouped_linear(inp_hidden_states, m_splits)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


Przemek Tredak's avatar
Przemek Tredak committed
519
520
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
    microbatching,
):
    """Sanity-check ``LayerNormMLP`` (hidden -> 4 * hidden -> hidden) across activations."""
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        # Use the module-level helper, consistent with the other tests in this file.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
567
568
569
570


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
):
    """Sanity-check a pre-LayerNorm GPT-style ``TransformerLayer`` (no attention mask)."""
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        # Use the module-level helper, consistent with the other tests in this file.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
623
624
625
626
627
628


def test_sanity_gpt_126m():
    """Run ``test_sanity_gpt`` once on the "126m" config.

    Uses delayed-scaling FP8 (E4M3, amax history of 16) when FP8 is available,
    and no FP8 otherwise. The dtype is the last entry of ``param_types``.
    """
    delayed_scaling = (
        recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
        if fp8_available
        else None
    )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=delayed_scaling,
        model="126m",
        skip_wgrad=False,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
    )
Przemek Tredak's avatar
Przemek Tredak committed
644
645
646
647


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """Sanity-check a BERT-style ``TransformerLayer`` (post-LayerNorm residual, output LN).

    FP8 variants are skipped with the same availability checks as the other
    tests, so direct calls with any recipe are guarded too.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    """Run ``test_sanity_bert`` once on the "126m" config with delayed-scaling FP8.

    Uses E4M3 with an amax history of 1. ``test_sanity_bert`` skips when FP8 is
    unavailable.
    """
    delayed_scaling = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    bert_kwargs = {
        "dtype": param_types[-1],
        "fp8_recipe": delayed_scaling,
        "model": "126m",
        "skip_wgrad": False,
        "normalization": "LayerNorm",
    }
    test_sanity_bert(**bert_kwargs)
Przemek Tredak's avatar
Przemek Tredak committed
698
699
700
701


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """Sanity-check a decoder ``TransformerLayer`` with cross-attention (T5-style).

    FP8 variants are skipped with the same availability checks as the other
    tests, so direct calls with any recipe are guarded too.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    """Run ``test_sanity_T5`` once on the "126m" config with delayed-scaling FP8.

    Uses E4M3 with an amax history of 1. ``test_sanity_T5`` skips when FP8 is
    unavailable.
    """
    delayed_scaling = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    t5_kwargs = {
        "dtype": param_types[-1],
        "fp8_recipe": delayed_scaling,
        "model": "126m",
        "skip_wgrad": False,
        "normalization": "LayerNorm",
    }
    test_sanity_T5(**t5_kwargs)
Przemek Tredak's avatar
Przemek Tredak committed
752
753
754
755


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    """AMP sanity check for a float32 ``TransformerLayer`` run under ``torch.autocast``."""
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        # Use the module-level helper, consistent with the other tests in this file.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
789
790
791
792


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
def test_sanity_drop_path(dtype, fp8_recipe, model):
    """Sanity-check a TransformerLayer configured with ``drop_path_rate=1.0``.

    Weight gradients are always exercised (``skip_wgrad`` is passed as False).
    FP8 cases are skipped when the device or the model config cannot run them.
    """
    cfg = model_configs[model]

    if fp8_recipe is not None:
        # Bail out early if this FP8 flavor is unavailable here.
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not cfg.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    std = 0.023
    layer = TransformerLayer(
        cfg.hidden_size,
        4 * cfg.hidden_size,
        cfg.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, cfg.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=cfg.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )

    _test_sanity_e2e(layer, dtype, cfg, fp8_recipe, False)
Przemek Tredak's avatar
Przemek Tredak committed
828
829
830
831


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a TransformerLayer built with ``fuse_qkv_params=True``.

    FP8 cases are skipped when the device or the model config cannot run them.
    """
    cfg = model_configs[model]

    if fp8_recipe is not None:
        # Bail out early if this FP8 flavor is unavailable here.
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not cfg.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    std = 0.023
    layer = TransformerLayer(
        cfg.hidden_size,
        4 * cfg.hidden_size,
        cfg.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, cfg.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=cfg.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )

    _test_sanity_e2e(layer, dtype, cfg, fp8_recipe, skip_wgrad)
868
869
870
871


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a TransformerLayer with fused QKV params and fused wgrad accumulation.

    The layer is built with ``fuse_qkv_params=True`` and
    ``fuse_wgrad_accumulation=True`` and handed to
    ``_test_sanity_e2e_gradient_accumulation_fusion``. FP8 cases are skipped
    when the device or the model config cannot run them.
    """
    cfg = model_configs[model]

    if fp8_recipe is not None:
        # Bail out early if this FP8 flavor is unavailable here.
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not cfg.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    std = 0.023
    layer = TransformerLayer(
        cfg.hidden_size,
        4 * cfg.hidden_size,
        cfg.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, cfg.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=cfg.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(layer, dtype, cfg, fp8_recipe, skip_wgrad)
909
910


911
def test_model_multiple_cast():
    """A Linear module produces FP32 output by default and FP16 after ``.half()``."""
    linear = Linear(16, 32)
    inp = torch.zeros((16, 16), device="cuda")

    out_fp32 = linear(inp)
    assert out_fp32.dtype == torch.float32

    # Cast both the module and its input to half precision and run again.
    linear.half()
    out_fp16 = linear(inp.half())
    assert out_fp16.dtype == torch.float16
923
924
925
926
927
928


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
    """Run a GEMM whose operands are views starting at element offsets into one buffer."""
    numel = N * N
    buf = torch.randn(numel + 2 * offset, device="cuda", dtype=datatype)

    # Two N x N operands carved out of ``buf``: one starting ``offset`` elements
    # in, the other ``2 * offset`` elements in, so neither begins at the
    # buffer's own start address.
    inp_view = buf[offset : offset + numel].view(N, N)
    weight_view = buf[2 * offset :].view(N, N)

    general_gemm(A=weight_view, B=inp_view, workspace=get_workspace())
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    """Run an FP8 GEMM on operands sliced at different offsets from one quantized buffer."""
    offset = 16
    buf = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)

    # Quantizer with scale and amax both set to 1, targeting FP8 E4M3.
    quantizer = Float8Quantizer(
        torch.ones(1).cuda().squeeze(),
        torch.ones(1).cuda().squeeze(),
        tex.DType.kFloat8E4M3,
    )
    buf_fp8 = quantizer(buf)

    # Both N x N operands come from the first row; the second starts
    # ``offset`` elements after the first.
    row = buf_fp8[0]
    inp_fp8 = torch.reshape(row[:-offset], (N, N))
    weight_fp8 = torch.reshape(row[offset:], (N, N))

    general_gemm(
        weight_fp8,
        inp_fp8,
        get_workspace(),
        datatype,  # output dtype matches the high-precision input dtype
        bias=None,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()
963
964


965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """Check that ``replace_raw_data`` swaps a Float8Tensor's ``_data`` buffer.

    The tensor must end up pointing at the supplied buffer, with identical
    byte values, and every other tracked attribute must be the same object.
    """
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    qtensor = quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
    source = torch.randn(qtensor.shape, dtype=torch.bfloat16, device="cuda")
    quantizer.update_quantized(source, qtensor)

    tracked = ("_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid")
    snapshot = {name: getattr(qtensor, name) for name in tracked}

    previous = qtensor._data
    replacement = torch.empty_like(previous)
    replace_raw_data(qtensor, replacement)

    # The tensor now references the replacement buffer, not the old one ...
    assert qtensor._data.data_ptr() != previous.data_ptr()
    assert qtensor._data.data_ptr() == replacement.data_ptr()
    # ... and that buffer holds exactly the same values.
    torch.testing.assert_close(previous, qtensor._data, atol=0, rtol=0)
    # Every other tracked attribute is still the very same object.
    for name, value in snapshot.items():
        assert getattr(qtensor, name) is value


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_init_high_precision_init_val():
    """Test fp8_model_init with preserve_high_precision_init_val=True.

    Verifies that the quantized weight carries a CPU copy of its high-precision
    initial value and exposes get/clear helpers. It also checks that
    re-quantizing that copy reproduces the weight bit-exactly, and that
    clearing removes both the returned value and the stored attribute.
    """
    with fp8_model_init(preserve_high_precision_init_val=True):
        model = Linear(768, 768)

    weight = model.weight

    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
    assert hasattr(
        weight, "clear_high_precision_init_val"
    ), "clear_high_precision_init_val() not found"

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

    # Re-quantize the preserved value with the weight's own quantizer; the
    # result must match the existing quantized weight exactly.
    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape, dtype=weight.dtype, device=weight.device
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)

    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )

    weight.clear_high_precision_init_val()
    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work"
    # Check the exact attribute name asserted present above. A dotted name
    # such as "._high_precision_init_val" can never exist, which would make
    # this check pass vacuously.
    assert not hasattr(
        weight, "_high_precision_init_val"
    ), "clear_high_precision_init_val() not work"
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules.

    Runs an identity ``torch.autograd.Function`` once through TE's
    ``checkpoint`` and once directly. It then compares the input gradients
    produced by the two passes.
    """

    # torch.autograd.Function
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    # Clone: ``inp.grad`` is accumulated in place by later backward passes, so
    # keeping a bare reference would alias the reference gradient below.
    grad_checkpoint = inp.grad.clone()

    # Reset so the reference pass produces a fresh gradient instead of
    # accumulating onto the checkpointed one.
    inp.grad = None
    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)
1058
1059


1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
    """Test that memory usage is optimized when weights are frozen.

    Runs ``Linear`` under ``fp8_autocast()`` with its default recipe and the
    weight frozen (``requires_grad=False``). It then bounds how much the CUDA
    peak allocation grows during the backward pass.
    """
    # ``max_memory_allocated`` reports the peak since process start or the
    # last reset. Without a reset, a larger peak left over from earlier tests
    # in the same process would hide any growth during this backward pass and
    # make the assertion vacuous. After the reset, the baseline is this test's
    # own forward pass.
    torch.cuda.reset_peak_memory_stats()

    dim = 1024
    linear = Linear(dim, dim, bias=False)
    x = torch.randn(dim, dim, requires_grad=True, device="cuda")

    # Freeze weights
    linear.weight.requires_grad = False

    # Forward and backward pass with FP8
    with fp8_autocast():
        o = linear(x)
        g_o = torch.randn_like(o)

    max_memory_before_backward = torch.cuda.max_memory_allocated()
    o.backward(g_o)
    max_memory_after_backward = torch.cuda.max_memory_allocated()

    memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
    assert memory_diff < 5.5, (
        f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
        " grad_output should be quantized only columnwise."
    )


1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights.

    Modules created under ``torch.no_grad()`` and ``fp8_model_init`` must hold
    only the quantized data needed for inference. That means no FP8 transpose
    and no column-wise MXFP8 data. This is checked right after construction
    and again after a forward pass under ``torch.inference_mode()``.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    seq_len = 32
    hidden = 32

    # Skip configurations the current device cannot run.
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Translate the parametrized name into a recipe; no entry means no quantization.
    recipe_factories = {
        "fp8_delayed_scaling": recipe.DelayedScaling,
        "fp8_current_scaling": recipe.Float8CurrentScaling,
        "mxfp8": recipe.MXFP8BlockScaling,
    }
    quantized = quantization not in (None, "None")
    make_recipe = recipe_factories.get(quantization)
    quant_recipe = None if make_recipe is None else make_recipe()

    # Constructors are deferred so they run inside the init contexts below.
    builders = {
        "Linear": lambda: Linear(hidden, hidden),
        "LayerNormLinear": lambda: LayerNormLinear(hidden, hidden),
        "LayerNormMLP": lambda: LayerNormMLP(hidden, hidden),
        "GroupedLinear": lambda: GroupedLinear(1, hidden, hidden),
        "ops.Linear": lambda: transformer_engine.pytorch.ops.Linear(hidden, hidden),
    }
    module = None
    with torch.no_grad(), fp8_model_init(enabled=quantized, recipe=quant_recipe):
        if module_name in builders:
            module = builders[module_name]()

    def check_weights():
        """Assert that every quantized parameter holds only inference-time data."""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # State right after construction.
    check_weights()

    # State after an inference-mode forward pass.
    with torch.inference_mode():
        x = torch.zeros(seq_len, hidden, device="cuda")
        extra = {"m_splits": [seq_len]} if module_name == "GroupedLinear" else {}
        with fp8_autocast(enabled=quantized, fp8_recipe=quant_recipe):
            module(x, **extra)
    check_weights()