test_sanity.py 44.9 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
from dataclasses import dataclass
from typing import Optional
7
from contextlib import nullcontext
8

Przemek Tredak's avatar
Przemek Tredak committed
9
10
import torch
import pytest
11
import os
Przemek Tredak's avatar
Przemek Tredak committed
12

13
import transformer_engine.pytorch
14
15
16
17
18
from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    fp8_model_init,
)
Przemek Tredak's avatar
Przemek Tredak committed
19
from transformer_engine.pytorch.utils import (
20
    get_device_compute_capability,
Przemek Tredak's avatar
Przemek Tredak committed
21
22
    init_method_normal,
    scaled_init_method_normal,
23
    is_bf16_compatible,
24
    get_cudnn_version,
Przemek Tredak's avatar
Przemek Tredak committed
25
26
27
28
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
29
    GroupedLinear,
Przemek Tredak's avatar
Przemek Tredak committed
30
31
    LayerNormMLP,
    TransformerLayer,
32
33
    RMSNorm,
    LayerNorm,
34
    get_cpu_offload_context,
Przemek Tredak's avatar
Przemek Tredak committed
35
36
)
from transformer_engine.common import recipe
37
import transformer_engine_torch as tex
38
from transformer_engine.pytorch.cpp_extensions import general_gemm
39
from transformer_engine.pytorch.module.base import get_workspace
40
41
42
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
43
44
    Float8Quantizer,
    Float8Tensor,
45
)
46
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
47
from transformer_engine.pytorch.tensor.utils import replace_raw_data
48
from transformer_engine.pytorch.distributed import checkpoint
49
from utils import ModelConfig, dtype_tols
Przemek Tredak's avatar
Przemek Tredak committed
50

51
# Only run FP8 tests on supported devices. Each probe yields an
# (is_available, reason_if_not) pair; the reason strings are used below as
# pytest skip messages.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
(
    fp8_block_scaling_available,
    reason_for_no_fp8_block_scaling,
) = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Seed the CPU and CUDA RNGs once at import time so every test in this module
# starts from the same, reproducible random state.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Opt-in debug mode: a non-zero integer in NVTE_TEST_NVINSPECT_ENABLED runs the
# module with the nvdlfw_inspect debug API initialized. The value is parsed
# with int(), so it must be an integer string such as "0" or "1".
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))


if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests should behave the same when debug=True. They are fed a
    # dummy feature (presumably via the config file below -- confirm) so that
    # debug mode is not switched off, which can happen when no feature is
    # active.
    # Both NVTE_TEST_NVINSPECT_CONFIG_FILE and NVTE_TEST_NVINSPECT_FEATURE_DIRS
    # must be set in this mode; os.environ[...] raises KeyError otherwise.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )

78
79
80
81
82
83
84

def create_meta(scale_factor: float, size: int = 1):
    """Build a ``tex.FP8TensorMeta`` holding ``size`` uniform scaling entries.

    ``scale`` is filled with ``scale_factor`` and ``scale_inv`` with its
    reciprocal. ``amax_history`` is a zeroed ``(1, size)`` buffer. All tensors
    are float32 on the CUDA device.
    """
    tensor_kwargs = dict(dtype=torch.float32, device="cuda")
    meta = tex.FP8TensorMeta()
    meta.amax_history = torch.zeros(1, size, **tensor_kwargs)
    ones = torch.ones(size, **tensor_kwargs)
    meta.scale_inv = ones / scale_factor
    meta.scale = ones * scale_factor
    return meta
85

Przemek Tredak's avatar
Przemek Tredak committed
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104

def custom_amax_to_scale(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: torch.Tensor,
    recipe: recipe.DelayedScaling,
) -> torch.Tensor:
    """Custom scaling-factor function used to exercise the recipe hook.

    Returns ``fp8_max / amax`` element-wise, falling back to the previous
    ``scale`` wherever ``amax`` is not positive or not finite. ``recipe`` is
    accepted to satisfy the callback signature but is not used.
    """
    candidate = fp8_max / amax
    # Only a strictly positive, finite amax yields a usable new scale.
    usable = (amax > 0.0) & torch.isfinite(amax)
    return torch.where(usable, candidate, scale)


def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    """Custom amax reduction used to exercise the recipe hook.

    Reduces ``amax_history`` along dim 0 by taking the element-wise minimum.
    """
    smallest, _indices = amax_history.min(dim=0)
    return smallest

105

106
107
108
109
110
111
112
113
114
def is_fp8_supported(config: ModelConfig):
    """Return whether ``config``'s shapes satisfy the FP8 alignment checks.

    Both token counts (``max_seqlen_q * batch_size`` and
    ``max_seqlen_kv * batch_size``) and both hidden sizes (``hidden_size`` and
    ``hidden_size_kv``) must be multiples of 16.
    """
    return (
        (config.max_seqlen_q * config.batch_size) % 16 == 0
        and (config.max_seqlen_kv * config.batch_size) % 16 == 0
        and config.hidden_size % 16 == 0
        and config.hidden_size_kv % 16 == 0
    )
Przemek Tredak's avatar
Przemek Tredak committed
115

116

Przemek Tredak's avatar
Przemek Tredak committed
117
# Model shapes used by the tests below (positional arguments follow
# utils.ModelConfig's signature).
model_configs = {
    "126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
    "small": ModelConfig(2, 32, 2, 32, num_layers=2),
    "weird": ModelConfig(3, 37, 3, 23, num_layers=2),
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}

# Recipes the sanity tests are parametrized over; ``None`` runs without FP8.
fp8_recipes = [
    None,  # Test non-FP8
    recipe.MXFP8BlockScaling(),  # Test default
    recipe.Float8CurrentScaling(),  # Test default
    recipe.Float8BlockScaling(),  # Test default
    recipe.DelayedScaling(),  # Test default
    recipe.DelayedScaling(  # Test most_recent algo
        amax_history_len=16,
        amax_compute_algo="most_recent",
    ),
    recipe.DelayedScaling(  # Test custom amax and scale compute algo
        fp8_format=recipe.Format.E4M3,
        amax_compute_algo=custom_amax_compute,
        scaling_factor_compute_algo=custom_amax_to_scale,
    ),
]

# bf16 requires sm_80 or higher, so it is only included when supported.
param_types = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_compatible() else []
)

all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
schetlur-nv's avatar
schetlur-nv committed
150

151

schetlur-nv's avatar
schetlur-nv committed
152
153
154
155
156
def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


157
158
159
160
161
162
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that resets ``FP8GlobalStateManager`` after each test.

    The tests enter ``fp8_autocast`` / ``fp8_model_init`` with many recipes.
    Resetting in teardown stops one test's global FP8 state from carrying
    over into the next.
    """
    try:
        yield
    finally:
        FP8GlobalStateManager.reset()


163
def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
    """Capture one training step of ``block`` in a CUDA graph and replay it.

    Runs three eager warm-up steps on a side stream. It then captures forward,
    MSE loss, backward and an SGD step into a ``torch.cuda.CUDAGraph``, copies
    fresh data into the static input/target buffers, and replays the graph.
    ``fp8_recipe=None`` disables FP8; ``skip_wgrad`` freezes all parameters.
    """
    criterion = torch.nn.MSELoss()
    sgd = torch.optim.SGD(block.parameters(), lr=0.1)

    # Static buffers read by the captured graph; replays see whatever data is
    # copied into them.
    shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    static_input = torch.randn(*shape, device="cuda", dtype=dtype, requires_grad=True)
    static_target = torch.randn(*shape, device="cuda", dtype=dtype)

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    use_fp8 = fp8_recipe is not None
    if skip_wgrad:
        _disable_wgrads(block)

    # Eager warm-up on a side stream before capture (PyTorch CUDA-graph recipe).
    warmup_stream = torch.cuda.Stream()
    warmup_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(warmup_stream):
        for _step in range(3):
            sgd.zero_grad(set_to_none=True)
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
                out = block(static_input)
            loss = criterion(out, static_target)
            loss.backward()
            sgd.step()
    torch.cuda.current_stream().wait_stream(warmup_stream)

    # Capture a complete training step.
    graph = torch.cuda.CUDAGraph()
    sgd.zero_grad(set_to_none=True)
    with torch.cuda.graph(graph):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
            static_output = block(static_input)
        static_loss = criterion(static_output, static_target)
        static_loss.backward()
        sgd.step()

    # Refill the static buffers with new data and replay the captured step.
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    graph.replay()

    torch.cuda.synchronize()


220
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run ``block`` under ``torch.autocast(dtype)`` and check the AMP dtypes.

    The input is float32. The output must come back in ``dtype``, while the
    input gradient and every trainable parameter's gradient must be float32.
    A random boolean attention mask of shape
    ``(1, 1, max_seqlen_q, max_seqlen_kv)`` is passed to the block.
    """
    hidden = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    hidden.retain_grad()
    attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            out = block(hidden, attention_mask=attn_mask)
        loss = out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert out.dtype == dtype, "AMP wrong output type."
    assert hidden.grad is not None, "Gradient should not be empty"
    assert hidden.grad.dtype == torch.float32, "AMP wrong dgrad type."
    trainable = ((name, param) for name, param in block.named_parameters() if param.requires_grad)
    for name, param in trainable:
        assert param.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


255
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    """Check that weight gradients are accumulated into ``main_grad`` buffers.

    Every trainable parameter whose name contains "weight" (excluding names
    containing "layer_norm_weight") gets a zeroed ``main_grad``. After one
    forward and backward pass, each of those buffers must hold at least one
    non-zero value.
    """

    def _tracks_main_grad(name, param):
        # Layer-norm weights are deliberately left out of the main_grad check.
        return "layer_norm_weight" not in name and "weight" in name and param.requires_grad

    hidden = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    for name, param in block.named_parameters():
        if _tracks_main_grad(name, param):
            param.main_grad = torch.zeros_like(param)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = block(hidden, attention_mask=attn_mask)
    loss = out.sum()
    loss.backward()
    torch.cuda.synchronize()

    failed_grads = [
        name
        for name, param in block.named_parameters()
        if _tracks_main_grad(name, param) and torch.count_nonzero(param.main_grad) == 0
    ]
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
293

Przemek Tredak's avatar
Przemek Tredak committed
294

295
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
    """Forward/backward smoke test for ``block``, optionally with CPU offload.

    When ``cpu_offload`` is set, the forward runs inside the context returned
    by ``get_cpu_offload_context(enabled=True)``. The output is then passed
    through the accompanying sync function before the backward pass.
    """

    def _passthrough(tensor):
        return tensor

    hidden = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    if cpu_offload:
        offload_context, sync_function = get_cpu_offload_context(enabled=True)
    else:
        offload_context, sync_function = nullcontext(), _passthrough

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
        out = block(hidden)
    out = sync_function(out)

    loss = out.sum()
    loss.backward()
    torch.cuda.synchronize()


321
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    """Forward/backward smoke test for ``block`` with a random attention mask.

    The boolean attention mask has shape ``(batch_size, 1, 1, max_seqlen_q)``.
    """
    hidden_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    hidden = torch.randn(hidden_shape, dtype=dtype, device="cuda", requires_grad=True)

    mask_shape = (config.batch_size, 1, 1, config.max_seqlen_q)
    attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = block(hidden, attention_mask=attn_mask)
    loss = out.sum()
    loss.backward()
    torch.cuda.synchronize()


347
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    """Forward/backward smoke test for an encoder-decoder style ``block``.

    The same random hidden states serve as both the layer input and
    ``encoder_output``. Random boolean masks are supplied for self-attention,
    shape ``(1, 1, max_seqlen_q, max_seqlen_kv)``, and for encoder-decoder
    attention, shape ``(batch_size, 1, 1, max_seqlen_kv)``.
    """
    hidden = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    self_attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )
    cross_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = block(
            hidden,
            attention_mask=self_attn_mask,
            encoder_output=hidden,
            enc_dec_attn_mask=cross_attn_mask,
        )
    loss = out.sum()
    loss.backward()
    torch.cuda.synchronize()


384
385
386
def _test_sanity_common(
    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
    """Shared forward/backward smoke test for a single TE module.

    ``skip_dgrad`` makes the input not require grad, and ``skip_wgrad``
    freezes all parameters. The test is skipped when both are set, since
    there would be nothing to differentiate. With ``microbatching``, the
    module is called twice (``is_first_microbatch=True`` then ``False``) and
    only the second output is used. Tuple outputs are reduced to their first
    element.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        if microbatching:
            _ = block(inp, is_first_microbatch=True)
            out = block(inp, is_first_microbatch=False)
        else:
            out = block(inp)

    if isinstance(out, tuple):
        out = out[0]
    loss = out.sum()
    loss.backward()
    torch.cuda.synchronize()


414
def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    """Run a normalization ``block`` under ``torch.autocast`` and check AMP dtypes.

    Mirrors ``_test_sanity_common``: the float32 input requires grad unless
    ``skip_dgrad`` is set, and the block's parameters are frozen when
    ``skip_wgrad`` is set. The test is skipped when both are set, since there
    would be nothing to differentiate. The output must be ``dtype``, and
    every gradient that is produced (input and trainable parameters) must be
    float32.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    # te_inp is created directly by randn, so it is a leaf: its .grad is
    # populated by backward() without retain_grad(). (retain_grad() would also
    # raise when requires_grad is False.)
    te_inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    if not skip_dgrad:
        assert te_inp.grad is not None, "Gradient should not be empty"
        assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    """AMP sanity check for a standalone LayerNorm / RMSNorm with float32 params."""
    config = model_configs[model]
    if normalization == "RMSNorm":
        norm_cls = RMSNorm
    else:
        norm_cls = LayerNorm

    block = norm_cls(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
451
452


Przemek Tredak's avatar
Przemek Tredak committed
453
454
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
    microbatching,
):
    """Forward/backward smoke test for ``LayerNormLinear`` (hidden -> 3 * hidden)."""
    config = model_configs[model]

    # Skip FP8 variants the current device/build (or model shape) can't run.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    block = LayerNormLinear(
        config.hidden_size,
        3 * config.hidden_size,
        init_method=init_method_normal(sigma),
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
496
497
498
499


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    """Forward/backward smoke test for a square ``Linear`` layer."""
    config = model_configs[model]

    # Skip FP8 variants the current device/build (or model shape) can't run.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=scaled_init_method_normal(sigma, config.num_layers),
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
528
529


530
531
532
533
534
535
536
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    """``Linear`` forward/backward where the batch may contain zero tokens.

    The input has ``bs * max_seqlen_q`` rows (zero when ``bs == 0``). Model
    parameters are optionally created under ``fp8_model_init``. The output
    must have shape ``(num_tokens, 4 * hidden_size)``.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.max_seqlen_q

    # Skip FP8 variants the current device/build (or model shape) can't run.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    quantized_params = use_fp8 and fp8_model_params
    with fp8_model_init(enabled=quantized_params, recipe=fp8_recipe):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


569
570
571
572
573
574
575
576
577
578
579
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
    """``GroupedLinear`` sanity test where one of the ``num_gemms`` splits is empty.

    Each split holds ``16 * bs * max_seqlen_q`` tokens, except the one chosen
    by ``empty_split`` ("first", "last" or "middle"), which holds zero. The
    input therefore has ``num_gemms - 1`` splits' worth of tokens.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    tokens_per_split = bs * config.max_seqlen_q
    num_tokens = tokens_per_split * (num_gemms - 1)

    # Skip FP8 variants the current device/build (or model shape) can't run.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    # Zero out the selected split; an unrecognized value leaves all splits full.
    m_splits = [tokens_per_split] * num_gemms
    empty_index = {"first": 0, "last": -1, "middle": num_gemms // 2}.get(empty_split)
    if empty_index is not None:
        m_splits[empty_index] = 0

    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_grouped_linear(inp_hidden_states, m_splits)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


Przemek Tredak's avatar
Przemek Tredak committed
622
623
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
    microbatching,
):
    """Forward/backward smoke test for ``LayerNormMLP`` with a 4x FFN width."""
    config = model_configs[model]

    # Skip FP8 variants the current device/build (or model shape) can't run.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
Przemek Tredak's avatar
Przemek Tredak committed
670
671
672
673


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
    cpu_offload,
):
    """GPT-style ``TransformerLayer`` smoke test.

    Builds the layer with a 4x FFN, 0.1 hidden/attention dropout,
    ``apply_residual_connection_post_layernorm=False`` and
    ``output_layernorm=False``, then runs it through ``_test_sanity_e2e``,
    optionally with CPU offload (not supported in debug mode).
    """
    if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("CPU offload is not supported in debug mode.")

    config = model_configs[model]

    # Skip FP8 variants the current device/build (or model shape) can't run.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)
733
734
735
736
737
738


def test_sanity_gpt_126m():
    """Run ``test_sanity_gpt`` once on the "126m" config.

    Uses an E4M3 ``DelayedScaling`` recipe (margin 0, "most_recent" amax over
    a 16-entry history) when FP8 is available, and no FP8 otherwise. The
    dtype is the last entry of ``param_types``.
    """
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
    else:
        fp8_recipe = None

    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=True,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
        cpu_offload=False,
    )
Przemek Tredak's avatar
Przemek Tredak committed
756
757
758
759


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    """BERT-style ``TransformerLayer`` smoke test.

    Builds the layer with ``apply_residual_connection_post_layernorm=True``,
    ``output_layernorm=True`` and ``self_attn_mask_type="causal"``, then runs
    it through ``_test_sanity_e2e_bert``.
    """
    config = model_configs[model]

    # Skip FP8 variants the current device/build (or model shape) can't run.
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        zero_centered_gamma=zero_centered_gamma,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    """Run the BERT sanity test once on the "126m" config with fixed settings.

    The delayed-scaling E4M3 FP8 recipe is only built when FP8 is available.
    ``test_sanity_bert`` skips any non-``None`` recipe on devices without FP8,
    so falling back to ``None`` keeps this test running in high precision
    there, matching ``test_sanity_gpt_126m``.
    """
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    """Sanity-check a T5-style decoder ``TransformerLayer`` end to end.

    The layer is built with ``layer_type="decoder"`` and handed to
    ``_test_sanity_e2e_T5``. FP8 runs are skipped when the device or the model
    config cannot support the requested recipe.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    """Run the T5 sanity test once on the "126m" config with fixed settings.

    The delayed-scaling E4M3 FP8 recipe is only built when FP8 is available.
    ``test_sanity_T5`` skips any non-``None`` recipe on devices without FP8,
    so falling back to ``None`` keeps this test running in high precision
    there, matching ``test_sanity_gpt_126m``.
    """
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a ``TransformerLayer`` through the AMP end-to-end helper.

    Parameters are created in FP32 (``params_dtype=torch.float32``). The
    parametrized ``dtype`` is passed to ``_test_sanity_e2e_amp`` instead,
    presumably as the autocast dtype (TODO confirm against the helper).
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a ``TransformerLayer`` built with ``drop_path_rate=1.0``.

    FP8 runs are skipped when the device or the model config cannot support
    the requested recipe. CPU offloading is disabled for this run.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )

    # Last positional argument is the cpu_offload flag.
    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a ``TransformerLayer`` built with ``fuse_qkv_params=True``.

    FP8 runs are skipped when the device or the model config cannot support
    the requested recipe. CPU offloading is disabled for this run.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )

    # Last positional argument is the cpu_offload flag.
    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
    """Sanity-check a ``TransformerLayer`` with fused weight-gradient accumulation.

    The layer is built with ``fuse_qkv_params=True`` and
    ``fuse_wgrad_accumulation=True`` and handed to
    ``_test_sanity_e2e_gradient_accumulation_fusion``. FP8 runs are skipped
    when the device or the model config cannot support the requested recipe.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    """Sanity-check a GPT-style ``TransformerLayer`` through the CUDA-graph helper.

    FP8 runs are skipped when the device or the model config cannot support
    the requested recipe. FP8 block-scaling recipes are always skipped here,
    since CUDA graphs are not supported for them.
    """
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if fp8_recipe.float8_block_scaling():
            pytest.skip("cuda graph not supported for float8_block_scaling recipe")
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        fuse_qkv_params=True,
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)


def test_model_multiple_cast():
    """Check that a ``Linear`` module's output dtype follows a later ``.half()`` cast.

    The module is first run with FP32 input and must produce FP32 output. After
    ``m.half()`` and an FP16 input, the output must be FP16.
    """
    a = torch.zeros((16, 16), device="cuda")
    m = Linear(16, 32)

    y = m(a)
    assert y.dtype == torch.float32

    m.half()
    a = a.half()

    y2 = m(a)
    assert y2.dtype == torch.float16


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
    """Run ``general_gemm`` on operands that start at odd element offsets.

    ``inp`` and ``weight`` are N x N views into a single scratch buffer,
    starting ``offset`` and ``2 * offset`` elements into it, so their data
    pointers are shifted from the buffer's start. The test only checks that
    the GEMM completes without error; the result is discarded.
    """
    scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
    inp = torch.reshape(scratchpad[offset:-offset], (N, N))
    weight = torch.reshape(scratchpad[offset * 2 :], (N, N))

    _ = general_gemm(A=weight, B=inp, workspace=get_workspace())
    # Synchronize so any asynchronous CUDA error from the GEMM surfaces here.
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    """Run an FP8 ``general_gemm`` on operands sliced at an element offset.

    A scratch buffer is quantized to FP8 E4M3 with unit scale and amax. Its
    first row is split into two N x N views, one ending ``offset`` elements
    early and one starting ``offset`` elements in. The test only checks that
    the GEMM completes without error; the result is discarded.
    """
    offset = 16
    scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)

    scales = torch.ones(1).cuda().squeeze()
    amaxes = torch.ones(1).cuda().squeeze()
    dtype = tex.DType.kFloat8E4M3
    fp8_quantizer = Float8Quantizer(scales, amaxes, dtype)

    outp_type = datatype

    scratchpad_fp8 = fp8_quantizer(scratchpad)
    inp_fp8 = torch.reshape(scratchpad_fp8[0][:-offset], (N, N))
    weight_fp8 = torch.reshape(scratchpad_fp8[0][offset:], (N, N))
    general_gemm(
        weight_fp8,
        inp_fp8,
        get_workspace(),
        outp_type,
        bias=None,
        use_split_accumulator=False,
    )
    # Synchronize so any asynchronous CUDA error from the GEMM surfaces here.
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """Check that ``replace_raw_data`` swaps a Float8Tensor's storage in place.

    Afterwards the tensor must be backed by the new buffer, hold exactly the
    same FP8 bytes as before, and keep each tracked attribute as the very same
    object it was before the swap.
    """
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    tensor = quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
    source = torch.randn(tensor.shape, dtype=torch.bfloat16, device="cuda")
    quantizer.update_quantized(source, tensor)

    tracked = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
    snapshot = {name: getattr(tensor, name) for name in tracked}

    previous = tensor._data
    replacement = torch.empty_like(previous)
    replace_raw_data(tensor, replacement)

    # Storage must now be the replacement buffer, not the previous one.
    current_ptr = tensor._data.data_ptr()
    assert current_ptr != previous.data_ptr()
    assert current_ptr == replacement.data_ptr()
    # Contents must be carried over byte-for-byte.
    torch.testing.assert_close(previous, tensor._data, atol=0, rtol=0)
    # Every tracked attribute must still be the identical object.
    for name, original_obj in snapshot.items():
        assert getattr(tensor, name) is original_obj


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_init_high_precision_init_val():
    """Test fp8_model_init with preserve_high_precision_init_val=True

    The quantized weight must carry a CPU copy of its high-precision initial
    value. Re-quantizing that copy must reproduce the weight exactly. After
    ``clear_high_precision_init_val()`` the stored value must be gone.
    """
    with fp8_model_init(preserve_high_precision_init_val=True):
        model = Linear(768, 768)

    weight = model.weight

    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
    assert hasattr(
        weight, "clear_high_precision_init_val"
    ), "clear_high_precision_init_val() not found"

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

    # Re-quantize the preserved copy; it must match the stored weight bit-for-bit.
    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape, dtype=weight.dtype, device=weight.device
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)

    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )

    weight.clear_high_precision_init_val()
    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work"
    # Clearing must drop the stored attribute itself, not just make the getter
    # return None. (Check the real attribute name; no leading dot.)
    assert not hasattr(
        weight, "_high_precision_init_val"
    ), "clear_high_precision_init_val() not work"


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules.

    An identity ``torch.autograd.Function`` is run once through ``checkpoint``
    and once directly. The input gradients from the two backward passes must
    match.
    """

    # torch.autograd.Function whose forward and backward are both identity.
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    # Snapshot and reset the gradient. Autograd accumulates into ``inp.grad``
    # (in place in the common case), so without this the second backward
    # would add onto the first and the comparison below would be meaningless.
    grad_checkpoint = inp.grad.clone()
    inp.grad = None

    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)


@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights

    The module is built under ``torch.no_grad()`` (with quantized weights via
    ``fp8_model_init`` when ``quantization`` is set) and run once under
    ``torch.inference_mode()``. Both before and after the forward pass,
    Float8Tensor weights must have no transpose cache and MXFP8Tensor weights
    must have no column-wise data.
    """

    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    # Tensor dimensions
    sequence_length = 32
    hidden_size = 32

    # Skip invalid configurations
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Construct quantization recipe
    with_quantization = quantization not in (None, "None")
    quantization_recipe = None
    if quantization == "fp8_delayed_scaling":
        quantization_recipe = recipe.DelayedScaling()
    elif quantization == "fp8_current_scaling":
        quantization_recipe = recipe.Float8CurrentScaling()
    elif quantization == "mxfp8":
        quantization_recipe = recipe.MXFP8BlockScaling()

    # Construct module
    with torch.no_grad():
        with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
            if module_name == "Linear":
                module = Linear(hidden_size, hidden_size)
            elif module_name == "LayerNormLinear":
                module = LayerNormLinear(hidden_size, hidden_size)
            elif module_name == "LayerNormMLP":
                module = LayerNormMLP(hidden_size, hidden_size)
            elif module_name == "GroupedLinear":
                module = GroupedLinear(1, hidden_size, hidden_size)
            elif module_name == "ops.Linear":
                module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)
            else:
                # Fail loudly instead of tripping over a ``None`` module later.
                raise ValueError(f"Unsupported module_name: {module_name}")

    def check_weights():
        """Helper function to check that weight parameters have expected data"""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Check that modules have expected weights after initialization
    check_weights()

    # Check that modules have expected weights after forward pass
    with torch.inference_mode():
        x = torch.zeros(sequence_length, hidden_size, device="cuda")
        kwargs = {}
        if module_name == "GroupedLinear":
            kwargs["m_splits"] = [sequence_length]
        with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
            module(x, **kwargs)
    check_weights()