# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Optional, List

import torch
import pytest
import os
from torch.utils.cpp_extension import IS_HIP_EXTENSION

import transformer_engine
import transformer_engine.pytorch as te
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch import (
    autocast,
    quantized_model_init,
    LayerNormLinear,
    Linear,
    GroupedLinear,
    LayerNormMLP,
    TransformerLayer,
    RMSNorm,
    LayerNorm,
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
    Float8Tensor,
    MXFP8Tensor,
    checkpoint,
    QuantizedTensor,
    is_bf16_available,
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from utils import ModelConfig

# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
    return_reason=True
)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

# Seed RNGs so that runs of this script are reproducible.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))


if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests should behave the same when debug=True. A dummy
    # feature is fed to the debug API so that debug mode is not switched
    # off, which can happen when no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )


def is_fp8_supported(config: ModelConfig):
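    """Check that all relevant dimensions are divisible by 16, as required for FP8 execution."""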
    if (
        config.max_seqlen_q * config.batch_size % 16
        or config.max_seqlen_kv * config.batch_size % 16
    ):
        return False
    if config.hidden_size % 16 or config.hidden_size_kv % 16:
        return False
    return True


model_configs = {
    "126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
    "small": ModelConfig(2, 32, 2, 32, num_layers=2),
    "weird": ModelConfig(3, 37, 3, 23, num_layers=2),
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}


def nvfp4_vanilla():
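    """NVFP4 block-scaling recipe with default quantization params."""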
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
    return nvfp4_recipe


fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
    fp8_recipes.append(nvfp4_vanilla())  # TODO: fix check for this
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(None)

param_types = [torch.float32, torch.float16]
if is_bf16_available():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]

all_activations = [
    "gelu",
    "geglu",
    "glu",
    "qgelu",
    "qgeglu",
    "relu",
    "reglu",
    "srelu",
    "sreglu",
    "silu",
    "swiglu",
    "clamped_swiglu",
]
all_normalizations = ["LayerNorm", "RMSNorm"]


def _disable_wgrads(block):
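    """Freeze all parameters of the module so weight gradients are not computed."""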
    for p in block.parameters():
        p.requires_grad = False


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
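    """Reset the global FP8 state after each test."""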
    yield
    FP8GlobalStateManager.reset()


def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"):
    """
    Verify that tensors are stored in contiguous memory.

    Args:
        tensors: List or iterable of tensors to check
        num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4)
        tensor_name: Name to use in error messages
    """
    tensor_list = list(tensors)
    if len(tensor_list) < 2:
        return  # Nothing to check

    for i in range(1, len(tensor_list)):
        prev_tensor = tensor_list[i - 1]
        curr_tensor = tensor_list[i]

        # Calculate expected offset based on previous tensor size
        prev_numel = prev_tensor.numel()
        expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size()

        # Verify current tensor's data pointer is correctly offset
        expected_ptr = prev_tensor.data_ptr() + expected_offset
        actual_ptr = curr_tensor.data_ptr()

        assert (
            actual_ptr == expected_ptr
        ), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}"


def check_grouped_tensor_pointers(
    weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None
):
    """
    Verify that the weights' backing tensors are stored contiguously in GroupedTensor memory.
    TODO(ksivaman): This check could be made much more efficient; the brute-force approach is kept for now.
    """

    num_elems_in_a_data_byte = 1 if fp8_recipe is None else 2 if fp8_recipe.nvfp4() else 1

    # Check data.
    if hasattr(weights[0], "_data") and weights[0]._data is not None:
        data_tensors = [w._data for w in weights]
        check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data")

    # Check transpose.
    if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None:
        transpose_tensors = [w._transpose for w in weights]
        check_grouped_tensor_pointers_helper(
            transpose_tensors, num_elems_in_byte=1, tensor_name="transpose"
        )

    # Check scale_inv.
    if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None:
        scale_inv_tensors = [w._scale_inv for w in weights]
        check_grouped_tensor_pointers_helper(
            scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv"
        )

    # Check rowwise scale_inv.
    if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None:
        scale_inv_tensors = [w._rowwise_scale_inv for w in weights]
        check_grouped_tensor_pointers_helper(
            scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv"
        )

    # Check columnwise scale_inv.
    if (
        hasattr(weights[0], "_columnwise_scale_inv")
        and weights[0]._columnwise_scale_inv is not None
    ):
        columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights]
        check_grouped_tensor_pointers_helper(
            columnwise_scale_inv_tensors,
            num_elems_in_byte=1,
            tensor_name="columnwise scale_inv",
        )

    # Check rowwise amax.
    if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None:
        rowwise_amax_tensors = [w._rowwise_amax for w in weights]
        check_grouped_tensor_pointers_helper(
            rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax"
        )

    # Check columnwise amax.
    if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None:
        columnwise_amax_tensors = [w._columnwise_amax for w in weights]
        check_grouped_tensor_pointers_helper(
            columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax"
        )

    # Check rowwise data.
    if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None:
        rowwise_data_tensors = [w._rowwise_data for w in weights]
        check_grouped_tensor_pointers_helper(
            rowwise_data_tensors,
            num_elems_in_byte=num_elems_in_a_data_byte,
            tensor_name="rowwise data",
        )

    # Check columnwise data.
    if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None:
        columnwise_data_tensors = [w._columnwise_data for w in weights]
        check_grouped_tensor_pointers_helper(
            columnwise_data_tensors,
            num_elems_in_byte=num_elems_in_a_data_byte,
            tensor_name="columnwise data",
        )


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
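    """Run forward/backward under torch.autocast and check AMP output/grad dtypes."""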
    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with autocast(enabled=use_fp8, recipe=fp8_recipe):
            te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp_hidden_states.grad is not None, "Gradient should not be empty"
    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
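    """Check that weight gradients are accumulated into .main_grad buffers."""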
    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            p.main_grad = torch.zeros_like(p)

    use_fp8 = fp8_recipe is not None
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    failed_grads = []
    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            if not torch.count_nonzero(p.main_grad) > 0:
                failed_grads.append(name)
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."


def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
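    """Run a minimal forward/backward pass through the block."""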
    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
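    """Forward/backward pass with a BERT-style padding attention mask."""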
    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    te_inp_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.max_seqlen_q),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
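    """Forward/backward pass through a decoder block with self- and cross-attention masks."""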
    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    enc_dec_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            encoder_output=te_inp_hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_common(
    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
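    """Shared forward/backward sanity check, optionally using the microbatching path."""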
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        if not microbatching:
            te_out = block(te_inp)
        else:
            _ = block(te_inp, is_first_microbatch=True)
            te_out = block(te_inp, is_first_microbatch=False)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
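    """Forward/backward a normalization module under AMP and check output/grad dtypes."""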
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=True,
    )
    te_inp.retain_grad()

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp.grad is not None, "Gradient should not be empty"
    assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    config = model_configs[model]
    module = RMSNorm if normalization == "RMSNorm" else LayerNorm

    block = module(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
    microbatching,
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = LayerNormLinear(
        config.hidden_size,
        config.hidden_size * 3,
        init_method=init_method,
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=output_layer_init_method,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.max_seqlen_q

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    use_fp8 = fp8_recipe is not None
    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("single_param", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype,
    bs,
    model,
    fp8_recipe,
    fp8_model_params,
    use_bias,
    single_param,
    num_gemms,
    empty_split,
):
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)

    if single_param:
        os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4():
            pytest.skip("NVFP4 not supported for grouped linear")

    use_fp8 = fp8_recipe is not None
    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            ffn_hidden_size,
            bias=use_bias,
            params_dtype=dtype,
        ).cuda()

    # Verify that weights are stored in contiguous GroupedTensor storage.
    weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
    if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
        if single_param:
            check_grouped_tensor_pointers(weights, fp8_recipe)

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    m_splits = [bs * config.max_seqlen_q] * num_gemms
    if empty_split == "first":
        m_splits[0] = 0
    elif empty_split == "last":
        m_splits[-1] = 0
    elif empty_split == "middle":
        m_splits[num_gemms // 2] = 0

    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        out = te_grouped_linear(inp_hidden_states, m_splits)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)

    if single_param:
        del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
@pytest.mark.parametrize("checkpoint", all_boolean)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
    microbatching,
    checkpoint,
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
    activation_params = None if activation != "clamped_swiglu" else {"limit": 7.0, "alpha": 1.702}
    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        activation_params=activation_params,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
        checkpoint=checkpoint,
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_gpt_126m():
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
def test_sanity_drop_path(dtype, fp8_recipe, model):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


def test_model_multiple_cast():
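    """Check that a module can be cast to a new dtype after construction."""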
    a = torch.zeros((16, 16), device="cuda")
    m = Linear(16, 32)

    y = m(a)
    assert y.dtype == torch.float32

    m.half()
    a = a.half()

    y2 = m(a)
    assert y2.dtype == torch.float16


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
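    """Run a GEMM on views whose data pointers are deliberately unaligned."""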
    scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
    inp = torch.reshape(scratchpad[offset:-offset], (N, N))
    weight = torch.reshape(scratchpad[offset * 2 :], (N, N))

    _ = general_gemm(A=weight, B=inp)
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
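    """Run an FP8 GEMM on quantized views with deliberately unaligned data pointers."""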
    offset = 16
    scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)

    scales = torch.ones(1).cuda().squeeze()
    amaxes = torch.ones(1).cuda().squeeze()
    dtype = tex.DType.kFloat8E4M3
    fp8_quantizer = Float8Quantizer(scales, amaxes, dtype)

    outp_type = datatype

    scratchpad_fp8 = fp8_quantizer(scratchpad)
    inp_fp8 = torch.reshape(scratchpad_fp8[0][:-offset], (N, N))
    weight_fp8 = torch.reshape(scratchpad_fp8[0][offset:], (N, N))
    general_gemm(
        weight_fp8,
        inp_fp8,
        outp_type,
        bias=None,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """Test the functionality of replace_raw_data"""
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    fp8_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    fp8_tensor = fp8_quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
    random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
    fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)

    attrs_to_check = [
        "_quantizer",
        "_fp8_dtype",
        "_scale_inv",
        "_transpose",
        "_transpose_invalid",
    ]
    attrs = {}
    for attr in attrs_to_check:
        attrs[attr] = getattr(fp8_tensor, attr)

    old_data = fp8_tensor._data
    new_data = torch.empty_like(old_data)
    replace_raw_data(fp8_tensor, new_data)

    # Make sure the new_data is properly assigned.
    assert fp8_tensor._data.data_ptr() != old_data.data_ptr()
    assert fp8_tensor._data.data_ptr() == new_data.data_ptr()
    # Make sure the values are not changed.
    torch.testing.assert_close(old_data, fp8_tensor._data, atol=0, rtol=0)
    # Make sure other attributes are not changed (totally identical)
    for attr in attrs_to_check:
        assert id(getattr(fp8_tensor, attr)) == id(attrs[attr])


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_quantized_model_init_high_precision_init_val():
    """Test quantized_model_init with preserve_high_precision_init_val=True"""
    with quantized_model_init(preserve_high_precision_init_val=True):
        model = Linear(768, 768)

    weight = model.weight

    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
    assert hasattr(
        weight, "clear_high_precision_init_val"
    ), "clear_high_precision_init_val() not found"

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape, dtype=weight.dtype, device=weight.device
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)

    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )

    weight.clear_high_precision_init_val()
    assert (
        weight.get_high_precision_init_val() is None
    ), "clear_high_precision_init_val() did not work"
    assert not hasattr(
        weight, "_high_precision_init_val"
    ), "clear_high_precision_init_val() did not work"


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules."""

    # torch.autograd.Function
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    grad_checkpoint = inp.grad.clone()
    inp.grad = None

    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
    """Test that memory usage is optimized when weights are frozen, using the default recipe."""
    dim = 1024
    linear = Linear(dim, dim, bias=False)
    x = torch.randn(dim, dim, requires_grad=True, device="cuda")

    # Freeze weights
    linear.weight.requires_grad = False

    # Forward and backward pass with FP8
    with autocast():
        o = linear(x)
        g_o = torch.randn_like(o)

    max_memory_before_backward = torch.cuda.max_memory_allocated()
    o.backward(g_o)
    max_memory_after_backward = torch.cuda.max_memory_allocated()

    memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
    assert memory_diff < 5.5, (
        f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
        " grad_output should be quantized only columnwise."
    )


@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights"""
    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    # Tensor dimensions
    sequence_length = 32
    hidden_size = 32

    # Skip invalid configurations
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Construct quantization recipe
    with_quantization = quantization not in (None, "None")
    quantization_recipe = None
    if quantization == "fp8_delayed_scaling":
        quantization_recipe = recipe.DelayedScaling()
    elif quantization == "fp8_current_scaling":
        quantization_recipe = recipe.Float8CurrentScaling()
    elif quantization == "mxfp8":
        quantization_recipe = recipe.MXFP8BlockScaling()

    # Construct module
    module = None
    with torch.no_grad():
        with quantized_model_init(enabled=with_quantization, recipe=quantization_recipe):
            if module_name == "Linear":
                module = Linear(hidden_size, hidden_size)
            elif module_name == "LayerNormLinear":
                module = LayerNormLinear(hidden_size, hidden_size)
            elif module_name == "LayerNormMLP":
                module = LayerNormMLP(hidden_size, hidden_size)
            elif module_name == "GroupedLinear":
                module = GroupedLinear(1, hidden_size, hidden_size)
            elif module_name == "ops.Linear":
                module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)

    def check_weights():
        """Helper function to check that weight parameters have expected data"""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Check that modules have expected weights after initialization
    check_weights()

    # Check that modules have expected weights after forward pass
    with torch.inference_mode():
        x = torch.zeros(sequence_length, hidden_size, device="cuda")
        kwargs = {}
        if module_name == "GroupedLinear":
            kwargs["m_splits"] = [sequence_length]
        with autocast(enabled=with_quantization, recipe=quantization_recipe):
            y = module(x, **kwargs)
    check_weights()