test_sanity.py 40.6 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
from typing import Optional

Przemek Tredak's avatar
Przemek Tredak committed
7
8
import torch
import pytest
9
import os
yuguo's avatar
yuguo committed
10
from torch.utils.cpp_extension import IS_HIP_EXTENSION
Przemek Tredak's avatar
Przemek Tredak committed
11

12
import transformer_engine.pytorch
13
14
15
16
17
from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    fp8_model_init,
)
Przemek Tredak's avatar
Przemek Tredak committed
18
19
20
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
21
    is_bf16_compatible,
Przemek Tredak's avatar
Przemek Tredak committed
22
23
24
25
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
26
    GroupedLinear,
Przemek Tredak's avatar
Przemek Tredak committed
27
28
    LayerNormMLP,
    TransformerLayer,
29
30
    RMSNorm,
    LayerNorm,
Przemek Tredak's avatar
Przemek Tredak committed
31
32
)
from transformer_engine.common import recipe
33
import transformer_engine_torch as tex
34
from transformer_engine.pytorch.cpp_extensions import general_gemm
35
from transformer_engine.pytorch.module.base import get_workspace
36
37
38
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
39
40
    Float8Quantizer,
    Float8Tensor,
41
)
42
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
43
from transformer_engine.pytorch.tensor.utils import replace_raw_data
44
from transformer_engine.pytorch.distributed import checkpoint
45
from utils import ModelConfig
Przemek Tredak's avatar
Przemek Tredak committed
46

47
# Query FP8 support once at import time. Each probe is unpacked into an
# (available, reason) pair; the reason strings are used as pytest skip messages.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
(
    fp8_block_scaling_available,
    reason_for_no_fp8_block_scaling,
) = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Seed the RNGs when the module is imported so runs are repeatable.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Opt-in switch for running the sanity tests with the nvinspect debug API.
NVTE_TEST_NVINSPECT_ENABLED = int(os.getenv("NVTE_TEST_NVINSPECT_ENABLED", "0"))


if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests should behave the same with debug enabled. A dummy
    # feature is configured so that debug mode is not switched off, which
    # can happen when no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )

72

73
74
75
76
77
78
79
80
81
def is_fp8_supported(config: ModelConfig):
    """Return True when the config's sizes are all multiples of 16.

    Checked, in order: ``max_seqlen_q * batch_size``,
    ``max_seqlen_kv * batch_size``, ``hidden_size`` and ``hidden_size_kv``.
    The first value that is not divisible by 16 makes the result False.
    """
    alignment = 16

    q_tokens = config.max_seqlen_q * config.batch_size
    if q_tokens % alignment != 0:
        return False

    kv_tokens = config.max_seqlen_kv * config.batch_size
    if kv_tokens % alignment != 0:
        return False

    if config.hidden_size % alignment != 0:
        return False
    return config.hidden_size_kv % alignment == 0
Przemek Tredak's avatar
Przemek Tredak committed
82

83

Przemek Tredak's avatar
Przemek Tredak committed
84
# Named model shapes used by the tests below. The positional arguments are
# passed straight to ``ModelConfig``; presumably (batch_size, max_seqlen,
# num_heads, head_dim) -- TODO confirm against ``utils.ModelConfig``.
model_configs = dict(
    [
        ("126m", ModelConfig(2, 2048, 12, 64, num_layers=12)),
        ("small", ModelConfig(2, 32, 2, 32, num_layers=2)),
        ("weird", ModelConfig(3, 37, 3, 23, num_layers=2)),
        ("large", ModelConfig(2, 128, 4, 128, num_layers=1)),
    ]
)

91
92
93
94
95
96
97
98
99

def nvfp4_vanilla():
    """Return an ``NVFP4BlockScaling`` recipe with default ``QParams`` everywhere.

    The forward-input, forward-weight and backward-grad quantization settings
    are each replaced by a freshly constructed ``recipe.QParams()``.
    """
    vanilla = recipe.NVFP4BlockScaling()
    for field in ("fp4_quant_fwd_inp", "fp4_quant_fwd_weight", "fp4_quant_bwd_grad"):
        setattr(vanilla, field, recipe.QParams())
    return vanilla


100
101
102
# Recipes the tests are parametrized over. Only recipes whose availability
# flag is set are included; the trailing ``None`` entry runs the tests with
# FP8 autocast disabled.
fp8_recipes = []
if mxfp8_available:
    # TODO: fix check for this (NVFP4 is currently gated on the MXFP8 flag).
    fp8_recipes += [recipe.MXFP8BlockScaling(), nvfp4_vanilla()]
if fp8_block_scaling_available:
    fp8_recipes += [recipe.Float8BlockScaling()]
if fp8_available:
    fp8_recipes += [recipe.Float8CurrentScaling(), recipe.DelayedScaling()]
fp8_recipes += [None]
Przemek Tredak's avatar
Przemek Tredak committed
110

111
# Parameter dtypes to test; bf16 requires sm_80 or higher.
param_types = [torch.float32, torch.float16]
if is_bf16_compatible():
    param_types += [torch.bfloat16]

# Shared parametrization value lists.
all_boolean = [True, False]
batch_sizes_with_zero = list(range(3))
Przemek Tredak's avatar
Przemek Tredak committed
117

118
119
120
121
122
123
124
125
126
127
128
129
# Activation names passed to the MLP tests, two per row (the names suggest
# each row is a base activation followed by its gated variant).
all_activations = [
    "gelu", "geglu",
    "qgelu", "qgeglu",
    "relu", "reglu",
    "srelu", "sreglu",
    "silu", "swiglu",
]
# Normalization layer names accepted by the tests below.
all_normalizations = ["LayerNorm", "RMSNorm"]
schetlur-nv's avatar
schetlur-nv committed
131

132

schetlur-nv's avatar
schetlur-nv committed
133
134
135
136
137
def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


138
139
140
141
142
143
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture: let the test run, then reset ``FP8GlobalStateManager``."""
    # Nothing to set up; the code after ``yield`` is the teardown step.
    yield
    FP8GlobalStateManager.reset()


144
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of ``block`` under ``torch.autocast``.

    The input is created in FP32 on CUDA while autocast targets ``dtype``;
    ``fp8_autocast`` is enabled only when ``fp8_recipe`` is not ``None``.
    Afterwards the output must have dtype ``dtype`` and the input gradient and
    every trainable parameter gradient must be FP32.
    """
    seq_q, seq_kv = config.max_seqlen_q, config.max_seqlen_kv
    hidden_states = torch.randn(
        (seq_q, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    hidden_states.retain_grad()
    attn_mask = torch.randint(2, (1, 1, seq_q, seq_kv), dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    fp8_enabled = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
            output = block(hidden_states, attention_mask=attn_mask)
        # The reduction stays inside torch.autocast but outside fp8_autocast.
        loss = output.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert output.dtype == dtype, "AMP wrong output type."
    assert hidden_states.grad is not None, "Gradient should not be empty"
    assert hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, param in block.named_parameters():
        if not param.requires_grad:
            continue
        assert param.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


179
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    """Check that a backward pass writes into pre-allocated ``main_grad`` buffers.

    Every trainable parameter whose name contains ``"weight"`` (but not
    ``"layer_norm_weight"``) gets a zero-filled ``main_grad`` before the
    forward pass; after backward each of those buffers must contain at least
    one non-zero entry.
    """

    def _has_main_grad(name, param):
        # Layer-norm weights are excluded from the check.
        if "layer_norm_weight" in name:
            return False
        return "weight" in name and param.requires_grad

    seq_q, seq_kv = config.max_seqlen_q, config.max_seqlen_kv
    hidden_states = torch.randn(
        (seq_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    attn_mask = torch.randint(2, (1, 1, seq_q, seq_kv), dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    for name, param in block.named_parameters():
        if _has_main_grad(name, param):
            param.main_grad = torch.zeros_like(param)

    fp8_enabled = fp8_recipe is not None
    with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
        output = block(hidden_states, attention_mask=attn_mask)
    output.sum().backward()
    torch.cuda.synchronize()

    failed_grads = [
        name
        for name, param in block.named_parameters()
        if _has_main_grad(name, param) and not torch.count_nonzero(param.main_grad) > 0
    ]
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
217

Przemek Tredak's avatar
Przemek Tredak committed
218

219
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of ``block`` on a random CUDA input.

    No attention mask is passed. ``fp8_autocast`` is enabled only when
    ``fp8_recipe`` is not ``None``; the test passes if nothing raises.
    """
    input_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    hidden_states = torch.randn(input_shape, dtype=dtype, device="cuda", requires_grad=True)

    if skip_wgrad:
        _disable_wgrads(block)

    fp8_enabled = fp8_recipe is not None
    with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
        output = block(hidden_states)
    output.sum().backward()
    torch.cuda.synchronize()


238
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of ``block`` with a per-sample mask.

    The boolean attention mask has shape
    ``(batch_size, 1, 1, max_seqlen_q)``. ``fp8_autocast`` is enabled only
    when ``fp8_recipe`` is not ``None``; the test passes if nothing raises.
    """
    input_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    hidden_states = torch.randn(input_shape, dtype=dtype, device="cuda", requires_grad=True)

    mask_shape = (config.batch_size, 1, 1, config.max_seqlen_q)
    attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    fp8_enabled = fp8_recipe is not None
    with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
        output = block(hidden_states, attention_mask=attn_mask)
    output.sum().backward()
    torch.cuda.synchronize()


264
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of ``block`` with encoder inputs.

    The same random tensor is used as both the hidden states and the
    ``encoder_output``. Two random boolean masks are supplied: a
    ``(1, 1, max_seqlen_q, max_seqlen_kv)`` self-attention mask and a
    ``(batch_size, 1, 1, max_seqlen_kv)`` encoder-decoder mask.
    """
    seq_q, seq_kv = config.max_seqlen_q, config.max_seqlen_kv
    hidden_states = torch.randn(
        (seq_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    self_attn_mask = torch.randint(2, (1, 1, seq_q, seq_kv), dtype=torch.bool, device="cuda")
    cross_attn_mask = torch.randint(
        2, (config.batch_size, 1, 1, seq_kv), dtype=torch.bool, device="cuda"
    )

    if skip_wgrad:
        _disable_wgrads(block)

    fp8_enabled = fp8_recipe is not None
    with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
        output = block(
            hidden_states,
            attention_mask=self_attn_mask,
            encoder_output=hidden_states,
            enc_dec_attn_mask=cross_attn_mask,
        )
    output.sum().backward()
    torch.cuda.synchronize()


301
302
303
def _test_sanity_common(
    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
    """Run a forward/backward pass of a single module on a random CUDA input.

    The test is skipped when both ``skip_dgrad`` and ``skip_wgrad`` are set.
    With ``microbatching`` the block is called twice, first with
    ``is_first_microbatch=True`` and then with ``False``; only the second
    output is back-propagated. If the block returns a tuple, its first
    element is used as the output.
    """
    if skip_wgrad and skip_dgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    input_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    inp = torch.randn(input_shape, dtype=dtype, device="cuda", requires_grad=not skip_dgrad)

    if skip_wgrad:
        _disable_wgrads(block)

    fp8_enabled = fp8_recipe is not None
    with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
        if microbatching:
            _ = block(inp, is_first_microbatch=True)
            output = block(inp, is_first_microbatch=False)
        else:
            output = block(inp)

    if isinstance(output, tuple):
        output = output[0]
    output.sum().backward()
    torch.cuda.synchronize()


331
def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    """Run a normalization module under ``torch.autocast`` and check dtypes.

    The test is skipped when both ``skip_dgrad`` and ``skip_wgrad`` are set.
    The input is created on CUDA with torch's default dtype. The output must
    have dtype ``dtype``; the input gradient and every trainable parameter
    gradient must be FP32.
    """
    if skip_wgrad and skip_dgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    input_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    inp = torch.randn(input_shape, device="cuda", requires_grad=True)
    inp.retain_grad()

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        output = block(inp)
        loss = output.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert output.dtype == dtype, "AMP wrong output type."
    assert inp.grad is not None, "Gradient should not be empty"
    assert inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, param in block.named_parameters():
        if not param.requires_grad:
            continue
        assert param.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    """Sanity-check ``RMSNorm`` / ``LayerNorm`` with FP32 weights under autocast."""
    # "RMSNorm" selects RMSNorm; any other name selects LayerNorm.
    norm_cls = {"RMSNorm": RMSNorm}.get(normalization, LayerNorm)
    config = model_configs[model]

    block = norm_cls(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
368
369


Przemek Tredak's avatar
Przemek Tredak committed
370
371
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
    microbatching,
):
    """Sanity-check a forward/backward pass through ``LayerNormLinear``.

    The module maps ``hidden_size`` to ``3 * hidden_size``. With a non-None
    ``fp8_recipe`` the test is skipped when the recipe is unavailable on this
    device, when the model config is not FP8-compatible, or for NVFP4 with
    FP16 parameters.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        # Use this file's is_fp8_supported() helper, as the other tests here
        # do, instead of relying on a method being present on ModelConfig.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = LayerNormLinear(
        config.hidden_size,
        config.hidden_size * 3,
        init_method=init_method,
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
415
416
417
418


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    """Sanity-check a forward/backward pass through a square ``Linear`` layer.

    With a non-None ``fp8_recipe`` the test is skipped when the model config
    is not FP8-compatible, or for NVFP4 with FP16 parameters.
    """
    config = model_configs[model]
    uses_fp8 = fp8_recipe is not None

    if uses_fp8 and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if uses_fp8 and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=scaled_init_method_normal(init_std, config.num_layers),
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
443
444


445
446
447
448
449
450
451
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    """Sanity-check ``Linear`` on 2-D inputs, including an input with zero rows.

    The input has ``bs * max_seqlen_q`` rows, so ``bs == 0`` yields an empty
    tensor. After a forward/backward pass the output must have shape
    ``(num_tokens, 4 * hidden_size)``.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.max_seqlen_q

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        # Use this file's is_fp8_supported() helper, as the other tests here
        # do, instead of relying on a method being present on ModelConfig.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    use_fp8 = fp8_recipe is not None
    # Quantized parameters are requested only when a recipe is active.
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


486
487
488
489
490
491
492
493
494
495
496
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
    """Sanity-check ``GroupedLinear`` when one of the splits is empty.

    All splits have ``bs * 16 * max_seqlen_q`` rows except the one selected by
    ``empty_split``, which is set to zero. After a forward/backward pass the
    output must have shape ``(num_tokens, 4 * hidden_size)``.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    rows_per_split = bs * config.max_seqlen_q
    # One split is emptied below, so only (num_gemms - 1) splits carry rows.
    num_tokens = rows_per_split * (num_gemms - 1)

    fp8_enabled = fp8_recipe is not None
    if fp8_enabled:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4():
            pytest.skip("NVFP4 not supported for grouped linear")

    with fp8_model_init(enabled=fp8_enabled and fp8_model_params, recipe=fp8_recipe):
        grouped_linear = GroupedLinear(
            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    m_splits = [rows_per_split] * num_gemms
    empty_index = {"first": 0, "last": -1, "middle": num_gemms // 2}.get(empty_split)
    if empty_index is not None:
        m_splits[empty_index] = 0

    with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
        out = grouped_linear(hidden_states, m_splits)
    out.sum().backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


Przemek Tredak's avatar
Przemek Tredak committed
535
536
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
    microbatching,
):
    """Sanity-check a forward/backward pass through ``LayerNormMLP``.

    The FFN width is ``4 * hidden_size``. With a non-None ``fp8_recipe`` the
    test is skipped when the recipe is unavailable on this device, when the
    model config is not FP8-compatible, or for NVFP4 with FP16 parameters.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        # Use this file's is_fp8_supported() helper, as the other tests here
        # do, instead of relying on a method being present on ModelConfig.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
585
586
587
588


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
):
    """Sanity-check a ``TransformerLayer`` run without an attention mask.

    The layer is built without post-layernorm residual and without output
    layernorm. With a non-None ``fp8_recipe`` the test is skipped when the
    recipe is unavailable on this device, when the model config is not
    FP8-compatible, or for NVFP4 with FP16 parameters.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        # Use this file's is_fp8_supported() helper, as the other tests here
        # do, instead of relying on a method being present on ModelConfig.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
643
644
645
646
647
648


def test_sanity_gpt_126m():
    """Run ``test_sanity_gpt`` once on the "126m" config.

    Uses a ``DelayedScaling`` recipe when FP8 is available and no recipe
    otherwise; the dtype is the last entry of ``param_types``.
    """
    delayed_scaling = (
        recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
        if fp8_available
        else None
    )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=delayed_scaling,
        model="126m",
        skip_wgrad=False,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
    )
Przemek Tredak's avatar
Przemek Tredak committed
664
665
666
667


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """Sanity-check a ``TransformerLayer`` with post-layernorm residual and output layernorm.

    With a non-None ``fp8_recipe`` the test is skipped when FP8 is
    unavailable, when the model config is not FP8-compatible, or for NVFP4
    with FP16 parameters.
    """
    config = model_configs[model]
    uses_fp8 = fp8_recipe is not None

    if uses_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if uses_fp8 and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if uses_fp8 and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    """Run ``test_sanity_bert`` once on the "126m" config with ``DelayedScaling``.

    The recipe is always constructed here; ``test_sanity_bert`` skips when
    FP8 is not available. The dtype is the last entry of ``param_types``.
    """
    delayed_scaling = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=delayed_scaling,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
720
721
722
723


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """Sanity-check a ``TransformerLayer`` built with ``layer_type="decoder"``.

    With a non-None ``fp8_recipe`` the test is skipped when FP8 is
    unavailable, when the model config is not FP8-compatible, or for NVFP4
    with FP16 parameters.
    """
    config = model_configs[model]
    uses_fp8 = fp8_recipe is not None

    if uses_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if uses_fp8 and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
    if uses_fp8 and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    """Run ``test_sanity_T5`` once on the "126m" config with ``DelayedScaling``.

    The recipe is always constructed here; ``test_sanity_T5`` skips when FP8
    is not available. The dtype is the last entry of ``param_types``.
    """
    delayed_scaling = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=delayed_scaling,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
776
777
778
779


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check an FP32 ``TransformerLayer`` under ``torch.autocast``.

    Parameters are always created in FP32; ``dtype`` is only the autocast
    target. With a non-None ``fp8_recipe`` the test is skipped when the
    recipe is unavailable on this device, when the model config is not
    FP8-compatible, or for NVFP4 with an FP16 autocast dtype.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        # Use this file's is_fp8_supported() helper, as the other tests here
        # do, instead of relying on a method being present on ModelConfig.
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
815
816
817
818


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
def test_sanity_drop_path(dtype, fp8_recipe, model):
    """Build a TransformerLayer with `drop_path_rate=1.0` and run it through `_test_sanity_e2e`.

    The case is skipped when an FP8 recipe is requested that the device, the
    model config, or the requested dtype cannot support.
    """
    config = model_configs[model]

    # Skip guards only apply when a quantization recipe is requested.
    quantized = fp8_recipe is not None
    if quantized and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantized and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if quantized and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if quantized and not config.is_fp8_supported():
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    layer_options = {
        "init_method": init_method_normal(init_std),
        "output_layer_init_method": scaled_init_method_normal(init_std, config.num_layers),
        "hidden_dropout": 0.1,
        "attention_dropout": 0.1,
        "kv_channels": config.kv_channels,
        "params_dtype": dtype,
        "apply_residual_connection_post_layernorm": False,
        "output_layernorm": False,
        "drop_path_rate": 1.0,
        "device": "cuda",
    }
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        **layer_options,
    )

    # The last positional argument (skip_wgrad in sibling tests) is fixed to False here.
    _test_sanity_e2e(block, dtype, config, fp8_recipe, False)
Przemek Tredak's avatar
Przemek Tredak committed
856
857
858
859


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    """Build a TransformerLayer with `fuse_qkv_params=True` and run it through `_test_sanity_e2e`.

    The case is skipped when an FP8 recipe is requested that the device, the
    model config, or the requested dtype cannot support.
    """
    config = model_configs[model]

    # Skip guards only apply when a quantization recipe is requested.
    quantized = fp8_recipe is not None
    if quantized and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantized and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if quantized and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if quantized and not config.is_fp8_supported():
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    layer_options = {
        "init_method": init_method_normal(init_std),
        "output_layer_init_method": scaled_init_method_normal(init_std, config.num_layers),
        "hidden_dropout": 0.1,
        "attention_dropout": 0.1,
        "kv_channels": config.kv_channels,
        "params_dtype": dtype,
        "apply_residual_connection_post_layernorm": False,
        "output_layernorm": False,
        "fuse_qkv_params": True,
        "device": "cuda",
    }
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        **layer_options,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
898
899
900
901


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
    """Build a TransformerLayer with fused QKV params and fused wgrad accumulation.

    The layer is handed to `_test_sanity_e2e_gradient_accumulation_fusion`.
    The case is skipped when an FP8 recipe is requested that the device, the
    model config, or the requested dtype cannot support.
    """
    config = model_configs[model]

    # Skip guards only apply when a quantization recipe is requested.
    quantized = fp8_recipe is not None
    if quantized and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantized and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if quantized and fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if quantized and not config.is_fp8_supported():
        pytest.skip("Model config does not support FP8")
    if quantized and fp8_recipe.nvfp4() and dtype == torch.float16:
        pytest.skip("FP16 output for NVFP4 not supported")

    init_std = 0.023
    layer_options = {
        "init_method": init_method_normal(init_std),
        "output_layer_init_method": scaled_init_method_normal(init_std, config.num_layers),
        "hidden_dropout": 0.1,
        "attention_dropout": 0.1,
        "kv_channels": config.kv_channels,
        "params_dtype": dtype,
        "apply_residual_connection_post_layernorm": False,
        "output_layernorm": False,
        "fuse_qkv_params": True,
        "fuse_wgrad_accumulation": True,
        "device": "cuda",
    }
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        **layer_options,
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)
941
942


943
def test_model_multiple_cast():
    """Check that a Linear module's output dtype is FP32 at first and FP16 after `.half()`."""
    inp = torch.zeros((16, 16), device="cuda")
    linear = Linear(16, 32)

    # Freshly constructed module with a default-dtype input.
    out_fp32 = linear(inp)
    assert out_fp32.dtype == torch.float32

    # Cast both the module and the input to half precision and run again.
    linear.half()
    inp = inp.half()

    out_fp16 = linear(inp)
    assert out_fp16.dtype == torch.float16
955
956
957
958
959
960


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
    """Run `general_gemm` on two (N, N) operands sliced from one buffer at element offsets."""
    num_elements = N * N
    scratchpad = torch.randn(num_elements + 2 * offset, device="cuda", dtype=datatype)

    # Both operands are views into `scratchpad` starting `offset` and `2 * offset` elements in.
    inp = scratchpad[offset:-offset].reshape((N, N))
    weight = scratchpad[offset * 2 :].reshape((N, N))

    _ = general_gemm(A=weight, B=inp, workspace=get_workspace())
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    """Run an FP8 `general_gemm` on operands sliced from a quantized buffer at a 16-element offset."""
    offset = 16
    matrix_shape = (N, N)
    scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)

    # Quantizer built from unit scale/amax tensors and the E4M3 format.
    unit_scale = torch.ones(1).cuda().squeeze()
    unit_amax = torch.ones(1).cuda().squeeze()
    fp8_quantizer = Float8Quantizer(unit_scale, unit_amax, tex.DType.kFloat8E4M3)

    # Quantize once, then carve both operands out of the first row of the result.
    scratchpad_fp8 = fp8_quantizer(scratchpad)
    first_row = scratchpad_fp8[0]
    inp_fp8 = torch.reshape(first_row[:-offset], matrix_shape)
    weight_fp8 = torch.reshape(scratchpad_fp8[0][offset:], matrix_shape)

    general_gemm(
        weight_fp8,
        inp_fp8,
        get_workspace(),
        datatype,  # output dtype matches the high-precision input dtype
        bias=None,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()
995
996


997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """Check that `replace_raw_data` swaps the storage of a Float8 tensor and nothing else."""
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    fp8_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    fp8_tensor = fp8_quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
    random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
    fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)

    # Snapshot the attributes that must survive the swap untouched.
    tracked_names = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
    snapshot = {name: getattr(fp8_tensor, name) for name in tracked_names}

    old_data = fp8_tensor._data
    new_data = torch.empty_like(old_data)
    replace_raw_data(fp8_tensor, new_data)

    # The tensor must now be backed by the new buffer, not the old one.
    current_ptr = fp8_tensor._data.data_ptr()
    assert current_ptr != old_data.data_ptr()
    assert current_ptr == new_data.data_ptr()
    # The stored values must match the old buffer exactly.
    torch.testing.assert_close(old_data, fp8_tensor._data, atol=0, rtol=0)
    # Every tracked attribute must still be the very same object.
    for name, before in snapshot.items():
        assert getattr(fp8_tensor, name) is before


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_init_high_precision_init_val():
    """Test fp8_model_init with preserve_high_precision_init_val=True

    Checks that the quantized weight exposes the high-precision init value and
    its accessors, that re-quantizing that value reproduces the weight exactly,
    and that clearing it removes the stored attribute.
    """
    with fp8_model_init(preserve_high_precision_init_val=True):
        model = Linear(768, 768)

    weight = model.weight

    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
    assert hasattr(
        weight, "clear_high_precision_init_val"
    ), "clear_high_precision_init_val() not found"

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

    # Re-quantize the preserved value with the weight's own quantizer; the result
    # must dequantize to exactly the same values as the original weight.
    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape, dtype=weight.dtype, device=weight.device
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)

    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )

    weight.clear_high_precision_init_val()
    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work"
    # Check the real attribute name: a name with a leading "." can never exist,
    # which would make this assertion pass unconditionally.
    assert not hasattr(
        weight, "_high_precision_init_val"
    ), "clear_high_precision_init_val() not work"
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules.

    Runs an identity `torch.autograd.Function` once through `checkpoint` and
    once directly, and compares the input gradients of the two passes.
    """

    # torch.autograd.Function whose `apply` is a plain callable rather than a module
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    # Snapshot the gradient and reset it: the next backward accumulates into
    # `inp.grad`, so keeping a bare reference would compare a tensor with itself.
    grad_checkpoint = inp.grad.clone()
    inp.grad = None

    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)
1090
1091


1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
    """Check that backward adds little peak memory when the Linear weight is frozen.

    Runs under `fp8_autocast()` with no explicit recipe and compares
    `torch.cuda.max_memory_allocated()` before and after the backward pass.
    """
    dim = 1024
    linear = Linear(dim, dim, bias=False)
    x = torch.randn(dim, dim, requires_grad=True, device="cuda")

    # Freeze the weight so that no weight gradient is required.
    linear.weight.requires_grad = False

    # Forward pass under FP8 autocast; the incoming gradient is created there too.
    with fp8_autocast():
        o = linear(x)
        g_o = torch.randn_like(o)

    # Peak allocation readings bracketing the backward pass.
    peak_before = torch.cuda.max_memory_allocated()
    o.backward(g_o)
    peak_after = torch.cuda.max_memory_allocated()

    memory_diff = (peak_after - peak_before) / 1e6
    assert memory_diff < 5.5, (
        f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
        " grad_output should be quantized only columnwise."
    )


1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights

    The module is constructed under `torch.no_grad()` and run under
    `torch.inference_mode()`; its quantized parameters are checked after
    construction and again after the forward pass.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    # Tensor dimensions
    sequence_length = 32
    hidden_size = 32

    # Skip invalid configurations
    wants_fp8 = quantization in ("fp8_delayed_scaling", "fp8_current_scaling")
    if wants_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Construct quantization recipe (factories keep the recipe lookup lazy)
    with_quantization = quantization not in (None, "None")
    recipe_factories = {
        "fp8_delayed_scaling": lambda: recipe.DelayedScaling(),
        "fp8_current_scaling": lambda: recipe.Float8CurrentScaling(),
        "mxfp8": lambda: recipe.MXFP8BlockScaling(),
    }
    make_recipe = recipe_factories.get(quantization)
    quantization_recipe = make_recipe() if make_recipe is not None else None

    # Construct module
    module_factories = {
        "Linear": lambda: Linear(hidden_size, hidden_size),
        "LayerNormLinear": lambda: LayerNormLinear(hidden_size, hidden_size),
        "LayerNormMLP": lambda: LayerNormMLP(hidden_size, hidden_size),
        "GroupedLinear": lambda: GroupedLinear(1, hidden_size, hidden_size),
        "ops.Linear": lambda: transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size),
    }
    make_module = module_factories.get(module_name)
    module = None
    with torch.no_grad():
        with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
            if make_module is not None:
                module = make_module()

    def check_weights():
        """Helper function to check that weight parameters have expected data"""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Check that modules have expected weights after initialization
    check_weights()

    # Check that modules have expected weights after forward pass
    with torch.inference_mode():
        x = torch.zeros(sequence_length, hidden_size, device="cuda")
        # GroupedLinear additionally needs the per-group split sizes.
        kwargs = {"m_splits": [sequence_length]} if module_name == "GroupedLinear" else {}
        with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
            _ = module(x, **kwargs)
    check_weights()