# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Optional, List

import torch
import pytest
import os

import transformer_engine
import transformer_engine.pytorch as te
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch import (
    autocast,
    quantized_model_init,
    LayerNormLinear,
    Linear,
    GroupedLinear,
    LayerNormMLP,
    TransformerLayer,
    RMSNorm,
    LayerNorm,
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
    Float8Tensor,
    MXFP8Tensor,
    checkpoint,
    QuantizedTensor,
    is_bf16_available,
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from utils import ModelConfig

# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

# Record initial RNG state from script run.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))


if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests should behave the same when debug=True.
    # A dummy feature is fed to the debug API to prevent debug mode
    # from being switched off, which can happen when no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )


def is_fp8_supported(config: ModelConfig):
    if (
        config.max_seqlen_q * config.batch_size % 16
        or config.max_seqlen_kv * config.batch_size % 16
    ):
        return False
    if config.hidden_size % 16 or config.hidden_size_kv % 16:
        return False
    return True


model_configs = {
    "126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
    "small": ModelConfig(2, 32, 2, 32, num_layers=2),
    "weird": ModelConfig(3, 37, 3, 23, num_layers=2),
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}


def nvfp4_vanilla():
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
    return nvfp4_recipe


fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
    fp8_recipes.append(nvfp4_vanilla())  # TODO: fix check for this
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(None)

param_types = [torch.float32, torch.float16]
if is_bf16_available():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]

all_activations = [
    "gelu",
    "geglu",
    "glu",
    "qgelu",
    "qgeglu",
    "relu",
    "reglu",
    "srelu",
    "sreglu",
    "silu",
    "swiglu",
    "clamped_swiglu",
]
all_normalizations = ["LayerNorm", "RMSNorm"]


def _disable_wgrads(block):
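    """Disable weight gradient computation for all parameters of a module."""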
    for p in block.parameters():
        p.requires_grad = False


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
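    """Reset the global FP8 state after every test."""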
    yield
    FP8GlobalStateManager.reset()


def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"):
    """
    Verify that tensors are stored in contiguous memory.

    Args:
        tensors: List or iterable of tensors to check
        num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4)
        tensor_name: Name to use in error messages
    """
    tensor_list = list(tensors)
    if len(tensor_list) < 2:
        return  # Nothing to check

    for i in range(1, len(tensor_list)):
        prev_tensor = tensor_list[i - 1]
        curr_tensor = tensor_list[i]

        # Calculate expected offset based on previous tensor size
        prev_numel = prev_tensor.numel()
        expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size()

        # Verify current tensor's data pointer is correctly offset
        expected_ptr = prev_tensor.data_ptr() + expected_offset
        actual_ptr = curr_tensor.data_ptr()

        assert (
            actual_ptr == expected_ptr
        ), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}"


def check_grouped_tensor_pointers(
    weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None
):
    """
    Verify that the pointers of the weights are in contiguous memory for GroupedTensor.
    TODO(ksivaman): This check could be made much more efficient; for now we keep the brute-force approach.
    """

    num_elems_in_a_data_byte = 1 if fp8_recipe is None else 2 if fp8_recipe.nvfp4() else 1

    # Check data.
    if hasattr(weights[0], "_data") and weights[0]._data is not None:
        data_tensors = [w._data for w in weights]
        check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data")

    # Check transpose.
    if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None:
        transpose_tensors = [w._transpose for w in weights]
        check_grouped_tensor_pointers_helper(
            transpose_tensors, num_elems_in_byte=1, tensor_name="transpose"
        )

    # Check scale_inv.
    if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None:
        scale_inv_tensors = [w._scale_inv for w in weights]
        check_grouped_tensor_pointers_helper(
            scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv"
        )

    # Check rowwise scale_inv.
    if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None:
        scale_inv_tensors = [w._rowwise_scale_inv for w in weights]
        check_grouped_tensor_pointers_helper(
            scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv"
        )

    # Check columnwise scale_inv.
    if (
        hasattr(weights[0], "_columnwise_scale_inv")
        and weights[0]._columnwise_scale_inv is not None
    ):
        columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights]
        check_grouped_tensor_pointers_helper(
            columnwise_scale_inv_tensors,
            num_elems_in_byte=1,
            tensor_name="columnwise scale_inv",
        )

    # Check rowwise amax.
    if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None:
        rowwise_amax_tensors = [w._rowwise_amax for w in weights]
        check_grouped_tensor_pointers_helper(
            rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax"
        )

    # Check columnwise amax.
    if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None:
        columnwise_amax_tensors = [w._columnwise_amax for w in weights]
        check_grouped_tensor_pointers_helper(
            columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax"
        )

    # Check rowwise data.
    if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None:
        rowwise_data_tensors = [w._rowwise_data for w in weights]
        check_grouped_tensor_pointers_helper(
            rowwise_data_tensors,
            num_elems_in_byte=num_elems_in_a_data_byte,
            tensor_name="rowwise data",
        )

    # Check columnwise data.
    if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None:
        columnwise_data_tensors = [w._columnwise_data for w in weights]
        check_grouped_tensor_pointers_helper(
            columnwise_data_tensors,
            num_elems_in_byte=num_elems_in_a_data_byte,
            tensor_name="columnwise data",
        )


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
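    """Run forward/backward under torch.autocast and check output/gradient dtypes."""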
    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with autocast(enabled=use_fp8, recipe=fp8_recipe):
            te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp_hidden_states.grad is not None, "Gradient should not be empty"
    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
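    """Run forward/backward and check that weight gradients are accumulated into main_grad."""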
    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            p.main_grad = torch.zeros_like(p)

    use_fp8 = fp8_recipe is not None
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    failed_grads = []
    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            if not torch.count_nonzero(p.main_grad) > 0:
                failed_grads.append(name)
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."


def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
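    """Run a basic forward/backward pass through the block."""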
    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
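    """Run forward/backward with a BERT-style attention mask."""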
    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    te_inp_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.max_seqlen_q),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
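    """Run forward/backward through a decoder block with encoder-decoder attention."""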
    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    enc_dec_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            encoder_output=te_inp_hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_common(
    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
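    """Run forward/backward, optionally exercising the is_first_microbatch code path."""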
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        if not microbatching:
            te_out = block(te_inp)
        else:
            _ = block(te_inp, is_first_microbatch=True)
            te_out = block(te_inp, is_first_microbatch=False)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
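    """Run a normalization module under torch.autocast and check output/gradient dtypes."""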
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=True,
    )
    te_inp.retain_grad()

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp.grad is not None, "Gradient should not be empty"
    assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    config = model_configs[model]
    module = RMSNorm if normalization == "RMSNorm" else LayerNorm

    block = module(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
    microbatching,
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = LayerNormLinear(
        config.hidden_size,
        config.hidden_size * 3,
        init_method=init_method,
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=output_layer_init_method,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.max_seqlen_q

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    use_fp8 = fp8_recipe is not None
    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("single_param", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype,
    bs,
    model,
    fp8_recipe,
    fp8_model_params,
    use_bias,
    single_param,
    num_gemms,
    empty_split,
):
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)

    if single_param:
        os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4():
            pytest.skip("NVFP4 not supported for grouped linear")

    use_fp8 = fp8_recipe is not None
    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            ffn_hidden_size,
            bias=use_bias,
            params_dtype=dtype,
        ).cuda()

    # Verify that weights are stored in contiguous GroupedTensor storage.
    weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
    if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
        if single_param:
            check_grouped_tensor_pointers(weights, fp8_recipe)

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    m_splits = [bs * config.max_seqlen_q] * num_gemms
    if empty_split == "first":
        m_splits[0] = 0
    elif empty_split == "last":
        m_splits[-1] = 0
    elif empty_split == "middle":
        m_splits[num_gemms // 2] = 0

    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        out = te_grouped_linear(inp_hidden_states, m_splits)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)

    if single_param:
        del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
@pytest.mark.parametrize("checkpoint", all_boolean)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
    microbatching,
    checkpoint,
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
    activation_params = None if activation != "clamped_swiglu" else {"limit": 7.0, "alpha": 1.702}
    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        activation_params=activation_params,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
        checkpoint=checkpoint,
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_gpt_126m():
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
def test_sanity_drop_path(dtype, fp8_recipe, model):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


def test_model_multiple_cast():
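    """Check that casting a module to a new dtype changes the output dtype accordingly."""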
    a = torch.zeros((16, 16), device="cuda")
    m = Linear(16, 32)

    y = m(a)
    assert y.dtype == torch.float32

    m.half()
    a = a.half()

    y2 = m(a)
    assert y2.dtype == torch.float16


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
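    """Check that general_gemm handles operands with unaligned data pointers."""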
    scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
    inp = torch.reshape(scratchpad[offset:-offset], (N, N))
    weight = torch.reshape(scratchpad[offset * 2 :], (N, N))

    _ = general_gemm(A=weight, B=inp)
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
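    """Check that an FP8 GEMM handles operands with unaligned data pointers."""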
    offset = 16
    scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)

    scales = torch.ones(1).cuda().squeeze()
    amaxes = torch.ones(1).cuda().squeeze()
    dtype = tex.DType.kFloat8E4M3
    fp8_quantizer = Float8Quantizer(scales, amaxes, dtype)

    outp_type = datatype

    scratchpad_fp8 = fp8_quantizer(scratchpad)
    inp_fp8 = torch.reshape(scratchpad_fp8[0][:-offset], (N, N))
    weight_fp8 = torch.reshape(scratchpad_fp8[0][offset:], (N, N))
    general_gemm(
        weight_fp8,
        inp_fp8,
        outp_type,
        bias=None,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """Test the functionality of replace_raw_data"""
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    fp8_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    fp8_tensor = fp8_quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
    random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
    fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)

    attrs_to_check = [
        "_quantizer",
        "_fp8_dtype",
        "_scale_inv",
        "_transpose",
        "_transpose_invalid",
    ]
    attrs = {}
    for attr in attrs_to_check:
        attrs[attr] = getattr(fp8_tensor, attr)

    old_data = fp8_tensor._data
    new_data = torch.empty_like(old_data)
    replace_raw_data(fp8_tensor, new_data)

    # Make sure the new_data is properly assigned.
    assert fp8_tensor._data.data_ptr() != old_data.data_ptr()
    assert fp8_tensor._data.data_ptr() == new_data.data_ptr()
    # Make sure the values are not changed.
    torch.testing.assert_close(old_data, fp8_tensor._data, atol=0, rtol=0)
    # Make sure other attributes are not changed (totally identical)
    for attr in attrs_to_check:
        assert id(getattr(fp8_tensor, attr)) == id(attrs[attr])


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_quantized_model_init_high_precision_init_val():
    """Test quantized_model_init with preserve_high_precision_init_val=True"""
    with quantized_model_init(preserve_high_precision_init_val=True):
        model = Linear(768, 768)

    weight = model.weight

    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
    assert hasattr(
        weight, "clear_high_precision_init_val"
    ), "clear_high_precision_init_val() not found"

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape, dtype=weight.dtype, device=weight.device
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)

    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )

    weight.clear_high_precision_init_val()
    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() did not work"
    assert not hasattr(
        weight, "_high_precision_init_val"
    ), "clear_high_precision_init_val() did not work"


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules."""

    # torch.autograd.Function
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    grad_checkpoint = inp.grad.clone()

    # Reset the input gradient so the two backward passes do not
    # accumulate into the same buffer.
    inp.grad = None
    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
    """Test that memory usage is optimized when weights are frozen for MXFP8."""
    dim = 1024
    linear = Linear(dim, dim, bias=False)
    x = torch.randn(dim, dim, requires_grad=True, device="cuda")

    # Freeze weights
    linear.weight.requires_grad = False

    # Forward and backward pass with FP8
    with autocast():
        o = linear(x)
        g_o = torch.randn_like(o)

    max_memory_before_backward = torch.cuda.max_memory_allocated()
    o.backward(g_o)
    max_memory_after_backward = torch.cuda.max_memory_allocated()

    memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
    assert memory_diff < 5.5, (
        f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
        " grad_output should be quantized only columnwise."
    )


@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights"""
    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    # Tensor dimensions
    sequence_length = 32
    hidden_size = 32

    # Skip invalid configurations
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Construct quantization recipe
    with_quantization = quantization not in (None, "None")
    quantization_recipe = None
    if quantization == "fp8_delayed_scaling":
        quantization_recipe = recipe.DelayedScaling()
    elif quantization == "fp8_current_scaling":
        quantization_recipe = recipe.Float8CurrentScaling()
    elif quantization == "mxfp8":
        quantization_recipe = recipe.MXFP8BlockScaling()

    # Construct module
    module = None
    with torch.no_grad():
        with quantized_model_init(enabled=with_quantization, recipe=quantization_recipe):
            if module_name == "Linear":
                module = Linear(hidden_size, hidden_size)
            elif module_name == "LayerNormLinear":
                module = LayerNormLinear(hidden_size, hidden_size)
            elif module_name == "LayerNormMLP":
                module = LayerNormMLP(hidden_size, hidden_size)
            elif module_name == "GroupedLinear":
                module = GroupedLinear(1, hidden_size, hidden_size)
            elif module_name == "ops.Linear":
                module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)

    def check_weights():
        """Helper function to check that weight parameters have expected data"""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Check that modules have expected weights after initialization
    check_weights()

    # Check that modules have expected weights after forward pass
    with torch.inference_mode():
        x = torch.zeros(sequence_length, hidden_size, device="cuda")
        kwargs = {}
        if module_name == "GroupedLinear":
            kwargs["m_splits"] = [sequence_length]
        with autocast(enabled=with_quantization, recipe=quantization_recipe):
            y = module(x, **kwargs)
    check_weights()