test_sanity.py 36.3 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
from typing import Optional

Przemek Tredak's avatar
Przemek Tredak committed
7
8
import torch
import pytest
9
import os
yuguo's avatar
yuguo committed
10
from torch.utils.cpp_extension import IS_HIP_EXTENSION
Przemek Tredak's avatar
Przemek Tredak committed
11

12
import transformer_engine.pytorch
13
14
15
16
17
from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    fp8_model_init,
)
Przemek Tredak's avatar
Przemek Tredak committed
18
19
20
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
21
    is_bf16_compatible,
Przemek Tredak's avatar
Przemek Tredak committed
22
23
24
25
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
26
    GroupedLinear,
Przemek Tredak's avatar
Przemek Tredak committed
27
28
    LayerNormMLP,
    TransformerLayer,
29
30
    RMSNorm,
    LayerNorm,
Przemek Tredak's avatar
Przemek Tredak committed
31
32
)
from transformer_engine.common import recipe
33
import transformer_engine_torch as tex
34
from transformer_engine.pytorch.cpp_extensions import general_gemm
35
from transformer_engine.pytorch.module.base import get_workspace
36
37
38
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
39
40
    Float8Quantizer,
    Float8Tensor,
41
)
42
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
43
from transformer_engine.pytorch.tensor.utils import replace_raw_data
44
from transformer_engine.pytorch.distributed import checkpoint
45
from utils import ModelConfig
Przemek Tredak's avatar
Przemek Tredak committed
46

47
# FP8 tests only run on devices that support them; record availability (and,
# where provided, the reason for unavailability) for each recipe family.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Seed the CPU and CUDA generators once at import time so every run of this
# script starts from the same RNG state.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Non-zero when the suite should run with the nvdlfw_inspect debug API enabled.
NVTE_TEST_NVINSPECT_ENABLED = int(os.getenv("NVTE_TEST_NVINSPECT_ENABLED", "0"))

if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests must behave the same with debug=True. The config file is
    # expected to enable a dummy feature so that debug mode is not switched off,
    # which can happen when no feature is active.
    import nvdlfw_inspect.api as debug_api

    _nvinspect_config_file = os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"]
    _nvinspect_feature_dirs = os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"]
    debug_api.initialize(_nvinspect_config_file, feature_dirs=_nvinspect_feature_dirs)

72

73
74
75
76
77
78
79
80
81
def is_fp8_supported(config: ModelConfig):
    """Return whether ``config``'s shapes pass the FP8 alignment checks.

    Both flattened token counts (``max_seqlen_q * batch_size`` and
    ``max_seqlen_kv * batch_size``) and both hidden sizes (``hidden_size`` and
    ``hidden_size_kv``) must be multiples of 16.
    """
    batch = config.batch_size
    tokens_aligned = (config.max_seqlen_q * batch) % 16 == 0 and (
        config.max_seqlen_kv * batch
    ) % 16 == 0
    if not tokens_aligned:
        return False
    return config.hidden_size % 16 == 0 and config.hidden_size_kv % 16 == 0
Przemek Tredak's avatar
Przemek Tredak committed
82

83

Przemek Tredak's avatar
Przemek Tredak committed
84
# Named model shapes shared by the tests below. The positional ModelConfig
# arguments are defined by ``utils.ModelConfig`` (presumably batch size,
# sequence length, head count and head dim -- confirm there).
model_configs = dict(
    [
        ("126m", ModelConfig(2, 2048, 12, 64, num_layers=12)),
        ("small", ModelConfig(2, 32, 2, 32, num_layers=2)),
        ("weird", ModelConfig(3, 37, 3, 23, num_layers=2)),
        ("large", ModelConfig(2, 128, 4, 128, num_layers=1)),
    ]
)

91
92
93
94
95
96
97
98
99
# Quantization recipes to parametrize over, limited to what this device
# supports. ``None`` (run without FP8) is always included as the last entry.
fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes += [recipe.Float8CurrentScaling(), recipe.DelayedScaling()]
fp8_recipes.append(None)

# Parameter dtypes. bf16 requires sm_80 or higher, so it is only added (as the
# last entry) when the device supports it.
param_types = [torch.float32, torch.float16]
param_types += [torch.bfloat16] if is_bf16_compatible() else []

all_boolean = [True, False]
batch_sizes_with_zero = list(range(3))  # [0, 1, 2]: includes an empty batch

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
schetlur-nv's avatar
schetlur-nv committed
110

111

schetlur-nv's avatar
schetlur-nv committed
112
113
114
115
116
def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


117
118
119
120
121
122
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that calls ``FP8GlobalStateManager.reset()`` after each test.

    The reset runs as fixture teardown, once the test body has finished, so
    global FP8 state is not carried from one test into the next.
    """
    yield
    FP8GlobalStateManager.reset()


123
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run ``block`` forward/backward under ``torch.autocast`` and check AMP dtypes.

    The input is created in fp32. Under autocast to ``dtype`` the output must be
    ``dtype``, while the input gradient and every trainable parameter's gradient
    must stay fp32. When ``fp8_recipe`` is given, ``fp8_autocast`` is nested
    inside the AMP region.
    """
    hidden_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    hidden_states = torch.randn(
        hidden_shape, dtype=torch.float32, device="cuda", requires_grad=True
    )
    hidden_states.retain_grad()
    mask_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
    attention_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    fp8_enabled = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
            output = block(hidden_states, attention_mask=attention_mask)
        loss = output.sum()
    loss.backward()
    torch.cuda.synchronize()

    assert output.dtype == dtype, "AMP wrong output type."
    assert hidden_states.grad is not None, "Gradient should not be empty"
    assert hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    trainable = ((name, param) for name, param in block.named_parameters() if param.requires_grad)
    for name, param in trainable:
        assert param.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


158
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    """Check that weight gradients are accumulated into each weight's ``main_grad``.

    Before the step, every trainable parameter whose name contains ``"weight"``
    (but not ``"layer_norm_weight"``) gets a zero-filled ``main_grad`` tensor.
    After one forward/backward pass, each of those buffers must contain at least
    one non-zero element.
    """

    def _accumulating_weights():
        # Trainable weights expected to write into main_grad; layer-norm weights excluded.
        for name, param in block.named_parameters():
            if "layer_norm_weight" not in name and "weight" in name and param.requires_grad:
                yield name, param

    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    mask_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
    attention_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    for _, param in _accumulating_weights():
        param.main_grad = torch.zeros_like(param)

    fp8_enabled = fp8_recipe is not None
    with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
        output = block(hidden_states, attention_mask=attention_mask)
    output.sum().backward()
    torch.cuda.synchronize()

    failed_grads = [
        name
        for name, param in _accumulating_weights()
        if not torch.count_nonzero(param.main_grad) > 0
    ]
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
196

Przemek Tredak's avatar
Przemek Tredak committed
197

198
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test one forward/backward pass of ``block`` with no attention mask.

    Weight gradients are disabled when ``skip_wgrad`` is set; FP8 is enabled
    when ``fp8_recipe`` is not ``None``.
    """
    hidden_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    hidden_states = torch.randn(hidden_shape, dtype=dtype, device="cuda", requires_grad=True)

    if skip_wgrad:
        _disable_wgrads(block)

    fp8_enabled = fp8_recipe is not None
    with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
        output = block(hidden_states)
    output.sum().backward()
    torch.cuda.synchronize()


217
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test one forward/backward pass of ``block`` with a BERT-style mask.

    The boolean attention mask has shape ``(batch_size, 1, 1, max_seqlen_q)``
    and random 0/1 entries.
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    mask_shape = (config.batch_size, 1, 1, config.max_seqlen_q)
    attention_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    fp8_enabled = fp8_recipe is not None
    with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
        output = block(hidden_states, attention_mask=attention_mask)
    output.sum().backward()
    torch.cuda.synchronize()


243
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test one forward/backward pass of ``block`` with encoder inputs (T5-style).

    The same hidden states are passed as both the block input and
    ``encoder_output``. Random boolean masks are used for self-attention,
    shape ``(1, 1, max_seqlen_q, max_seqlen_kv)``, and encoder-decoder
    attention, shape ``(batch_size, 1, 1, max_seqlen_kv)``.
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    self_attn_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
    self_attn_mask = torch.randint(2, self_attn_shape, dtype=torch.bool, device="cuda")
    cross_attn_shape = (config.batch_size, 1, 1, config.max_seqlen_kv)
    cross_attn_mask = torch.randint(2, cross_attn_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    fp8_enabled = fp8_recipe is not None
    with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
        output = block(
            hidden_states,
            attention_mask=self_attn_mask,
            encoder_output=hidden_states,
            enc_dec_attn_mask=cross_attn_mask,
        )
    output.sum().backward()
    torch.cuda.synchronize()


280
281
282
def _test_sanity_common(
    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
    """Run one forward/backward pass through a single TE module.

    ``skip_dgrad`` makes the input not require grad, and ``skip_wgrad`` freezes
    the module's parameters. Skips the test when both are set, since nothing
    would require grad. With ``microbatching`` the module is called twice, with
    ``is_first_microbatch=True`` and then ``False``, and only the second output
    is backpropagated. When the module returns a tuple, its first element is
    used as the output.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    fp8_enabled = fp8_recipe is not None
    with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
        if microbatching:
            _first_microbatch_out = block(inp, is_first_microbatch=True)
            output = block(inp, is_first_microbatch=False)
        else:
            output = block(inp)

    if isinstance(output, tuple):
        output = output[0]
    output.sum().backward()
    torch.cuda.synchronize()


310
def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    """Check AMP dtypes for a standalone normalization module.

    Runs ``block`` under ``torch.autocast`` (to ``dtype``) on an fp32 input.

    The output must be ``dtype``. The input gradient (computed unless
    ``skip_dgrad``) and the gradient of every trainable parameter (all frozen
    when ``skip_wgrad``) must stay fp32. Skips the test when both flags are
    set, since nothing would require grad.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    # The input is a leaf tensor, so .grad is populated without retain_grad().
    # retain_grad() would also raise when requires_grad is False.
    te_inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    # Honor skip_wgrad the same way _test_sanity_common does, so each
    # parametrized variant exercises a different gradient path.
    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    if not skip_dgrad:
        assert te_inp.grad is not None, "Gradient should not be empty"
        assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    """AMP sanity check for fp32 ``LayerNorm`` / ``RMSNorm`` modules."""
    config = model_configs[model]
    norm_cls = {"RMSNorm": RMSNorm}.get(normalization, LayerNorm)
    block = norm_cls(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
347
348


Przemek Tredak's avatar
Przemek Tredak committed
349
350
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
    microbatching,
):
    """Sanity-check ``LayerNormLinear`` (hidden -> 3 * hidden) forward/backward."""
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    sigma = 0.023
    block = LayerNormLinear(
        config.hidden_size,
        3 * config.hidden_size,
        init_method=init_method_normal(sigma),
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(
        block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=microbatching
    )
Przemek Tredak's avatar
Przemek Tredak committed
386
387
388
389


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    """Sanity-check a square ``Linear`` (hidden -> hidden) forward/backward."""
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=scaled_init_method_normal(0.023, config.num_layers),
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(
        block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=microbatching
    )
Przemek Tredak's avatar
Przemek Tredak committed
412
413


414
415
416
417
418
419
420
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    """``Linear`` must run forward/backward on ``bs * max_seqlen_q`` tokens, including zero.

    With ``fp8_model_params`` (and a recipe), the weights are created under
    ``fp8_model_init``. This combination is skipped in nvinspect debug mode.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.max_seqlen_q

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    # Allocate directly on the GPU so the input is a leaf tensor. Calling .cuda()
    # on a CPU tensor with requires_grad=True returns a non-leaf copy (whose
    # .grad is never populated) and wastes a host allocation plus a transfer.
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


447
448
449
450
451
452
453
454
455
456
457
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
    """Sanity-check ``GroupedLinear`` when one of its ``num_gemms`` splits is empty.

    Each split gets ``16 * bs * max_seqlen_q`` tokens, except the one selected
    by ``empty_split`` ("first", "last" or "middle"), which gets zero. The
    total input size is therefore ``num_tokens``.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    # Allocate directly on the GPU so the input is a leaf tensor. Calling .cuda()
    # on a CPU tensor with requires_grad=True returns a non-leaf copy (whose
    # .grad is never populated) and wastes a host allocation plus a transfer.
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )

    m_splits = [bs * config.max_seqlen_q] * num_gemms
    if empty_split == "first":
        m_splits[0] = 0
    elif empty_split == "last":
        m_splits[-1] = 0
    elif empty_split == "middle":
        m_splits[num_gemms // 2] = 0

    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_grouped_linear(inp_hidden_states, m_splits)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


Przemek Tredak's avatar
Przemek Tredak committed
494
495
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
    microbatching,
):
    """Sanity-check ``LayerNormMLP`` (FFN width 4 * hidden) forward/backward."""
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    sigma = 0.023
    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(
        block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=microbatching
    )
Przemek Tredak's avatar
Przemek Tredak committed
536
537
538
539


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
):
    """Sanity-check a GPT-style ``TransformerLayer`` forward/backward (no mask).

    Uses ``apply_residual_connection_post_layernorm=False`` and
    ``output_layernorm=False``, with dropout 0.1 for both hidden and attention.
    """
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    sigma = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )
    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
586
587
588
589
590
591


def test_sanity_gpt_126m():
    """Sanity-check a GPT-style layer at the "126m" config.

    Uses the last entry of ``param_types`` and, when FP8 is available, an E4M3
    delayed-scaling recipe. Without FP8 it runs in high precision.
    """
    fp8_recipe = (
        recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
        if fp8_available
        else None
    )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
    )
Przemek Tredak's avatar
Przemek Tredak committed
607
608
609
610


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """Sanity-check a BERT-style ``TransformerLayer`` with a causal self-attention mask.

    Uses ``apply_residual_connection_post_layernorm=True`` and
    ``output_layernorm=True``. With a recipe, the test skips if FP8 is
    unavailable or the config is not FP8-aligned.
    """
    config = model_configs[model]
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    sigma = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
    )
    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    """Sanity-check a BERT-style layer at the "126m" config.

    Uses the last entry of ``param_types``. When FP8 is available, it also uses
    an E4M3 delayed-scaling recipe (amax history length 1). Otherwise it runs in
    high precision instead of skipping entirely, matching
    ``test_sanity_gpt_126m``.
    """
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
661
662
663
664


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """Sanity-check a T5-style decoder ``TransformerLayer`` with cross-attention.

    With a recipe, the test skips if FP8 is unavailable or the config is not
    FP8-aligned.
    """
    config = model_configs[model]
    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    sigma = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        normalization=normalization,
        device="cuda",
    )
    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    """Sanity-check a T5-style decoder layer at the "126m" config.

    Uses the last entry of ``param_types``. When FP8 is available, it also uses
    an E4M3 delayed-scaling recipe (amax history length 1). Otherwise it runs in
    high precision instead of skipping entirely, matching
    ``test_sanity_gpt_126m``.
    """
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
715
716
717
718


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    """AMP sanity check for an fp32-parameter ``TransformerLayer`` autocast to ``dtype``."""
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    sigma = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )
    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
746
747
748
749


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
def test_sanity_drop_path(dtype, fp8_recipe, model):
    """Sanity-check ``TransformerLayer`` with ``drop_path_rate=1.0`` (weight grads enabled)."""
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    sigma = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )
    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad=False)
Przemek Tredak's avatar
Przemek Tredak committed
779
780
781
782


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check ``TransformerLayer`` built with ``fuse_qkv_params=True``."""
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    sigma = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )
    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
813
814
815
816


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a TransformerLayer with fused QKV params and fused wgrad accumulation.

    The layer (``fuse_qkv_params=True``, ``fuse_wgrad_accumulation=True``) is
    run through ``_test_sanity_e2e_gradient_accumulation_fusion``;
    ``skip_wgrad`` is forwarded unchanged.
    """
    config = model_configs[model]

    # FP8 runs are skipped for configs that the FP8 path cannot handle.
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    ffn_hidden_size = 4 * config.hidden_size
    layer = TransformerLayer(
        config.hidden_size,
        ffn_hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(layer, dtype, config, fp8_recipe, skip_wgrad)
848
849


850
def test_model_multiple_cast():
    """Check that a Linear module's output dtype follows casts of the module.

    The output is float32 by default, and float16 after ``.half()`` is applied
    to the module and a float16 input is fed in.
    """
    x = torch.zeros((16, 16), device="cuda")
    layer = Linear(16, 32)

    out_default = layer(x)
    assert out_default.dtype == torch.float32

    layer.half()
    x_half = x.half()

    out_half = layer(x_half)
    assert out_half.dtype == torch.float16
862
863
864
865
866
867


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
    """Run a GEMM whose operands are views at element offsets into one flat buffer.

    Both operands are ``N*N``-element windows of a single buffer. They start
    ``offset`` and ``2 * offset`` elements in, so the GEMM sees operands that
    do not start at the buffer's own base address.
    """
    numel = N * N
    buf = torch.randn(numel + 2 * offset, device="cuda", dtype=datatype)

    # buf[offset:-offset] and buf[2*offset:] are exactly these windows.
    inp = buf.narrow(0, offset, numel).reshape(N, N)
    weight = buf.narrow(0, 2 * offset, numel).reshape(N, N)

    general_gemm(A=weight, B=inp, workspace=get_workspace())
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    """Run an FP8 GEMM on operands sliced at an offset from one quantized buffer.

    A ``(N, N*N + offset)`` tensor is quantized to FP8 E4M3. Its first row is
    then cut into two ``N*N`` windows shifted by ``offset`` elements relative
    to each other, and those windows are used as GEMM operands.
    """
    offset = 16
    buf = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)

    # Quantizer whose scale and amax tensors are both one.
    unit_scale = torch.ones(1, device="cuda").squeeze()
    unit_amax = torch.ones(1, device="cuda").squeeze()
    quantizer = Float8Quantizer(unit_scale, unit_amax, tex.DType.kFloat8E4M3)

    buf_fp8 = quantizer(buf)
    inp_fp8 = buf_fp8[0][:-offset].reshape(N, N)
    weight_fp8 = buf_fp8[0][offset:].reshape(N, N)

    general_gemm(
        weight_fp8,
        inp_fp8,
        get_workspace(),
        datatype,
        bias=None,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()
902
903


904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """Test the functionality of replace_raw_data"""
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    tensor = quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
    source = torch.randn(tensor.shape, dtype=torch.bfloat16, device="cuda")
    quantizer.update_quantized(source, tensor)

    # Attributes that must still be the very same objects after the swap.
    preserved = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
    snapshot = {name: getattr(tensor, name) for name in preserved}

    previous_data = tensor._data
    replacement = torch.empty_like(previous_data)
    replace_raw_data(tensor, replacement)

    # _data must now live in the replacement buffer, not the old one ...
    assert tensor._data.data_ptr() != previous_data.data_ptr()
    assert tensor._data.data_ptr() == replacement.data_ptr()
    # ... and still hold bit-identical values.
    torch.testing.assert_close(previous_data, tensor._data, atol=0, rtol=0)
    # Every other tracked attribute is untouched (identity, not just equality).
    for name, obj in snapshot.items():
        assert getattr(tensor, name) is obj


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_init_high_precision_init_val():
    """Test fp8_model_init with preserve_high_precision_init_val=True

    Checks that:
    - the quantized weight carries a CPU copy of its high-precision init value;
    - re-quantizing that copy reproduces the weight bit-exactly;
    - ``clear_high_precision_init_val()`` removes the stored value.
    """
    with fp8_model_init(preserve_high_precision_init_val=True):
        model = Linear(768, 768)

    weight = model.weight

    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
    assert hasattr(
        weight, "clear_high_precision_init_val"
    ), "clear_high_precision_init_val() not found"

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

    # Re-quantize the saved high-precision value with the weight's own quantizer;
    # it must reproduce the stored quantized weight exactly.
    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape, dtype=weight.dtype, device=weight.device
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)

    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )

    weight.clear_high_precision_init_val()
    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work"
    # Check the real attribute name (the previous "._high_precision_init_val"
    # spelling could never exist, so the assertion was always vacuously true).
    assert not hasattr(
        weight, "_high_precision_init_val"
    ), "clear_high_precision_init_val() not work"
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules.

    An identity ``torch.autograd.Function`` is run once through TE's
    ``checkpoint`` and once directly. The input gradients from the two
    passes must match each other and the expected all-ones gradient of
    ``sum()``.
    """

    # torch.autograd.Function that returns its input and passes the incoming
    # gradient straight through.
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    # Snapshot and reset the leaf gradient. Without this, the second backward
    # accumulates into the same ``inp.grad`` tensor, so the two "results"
    # would be the same object (a vacuous comparison) or an accumulated sum.
    grad_checkpoint = inp.grad.clone()
    inp.grad = None

    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same, and equal d(sum(x))/dx == 1.
    torch.testing.assert_close(grad_checkpoint, grad_standard)
    torch.testing.assert_close(grad_standard, torch.ones_like(inp))
997
998


999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
    """Test that backward memory is small when Linear weights are frozen.

    Uses the default FP8 recipe (``fp8_autocast()`` with no explicit recipe).
    With ``weight.requires_grad = False`` no weight gradient is needed, so the
    extra peak memory during backward must stay below 5.5 MB.
    """
    dim = 1024
    linear = Linear(dim, dim, bias=False)
    x = torch.randn(dim, dim, requires_grad=True, device="cuda")

    # Freeze weights
    linear.weight.requires_grad = False

    # Forward and backward pass with FP8
    with fp8_autocast():
        o = linear(x)
        g_o = torch.randn_like(o)

    # max_memory_allocated() is a cumulative peak. Reset it so the baseline is
    # the current allocation rather than a peak left behind by forward or by an
    # earlier test, which would hide the backward allocations (diff == 0).
    torch.cuda.reset_peak_memory_stats()
    max_memory_before_backward = torch.cuda.max_memory_allocated()
    o.backward(g_o)
    max_memory_after_backward = torch.cuda.max_memory_allocated()

    memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
    assert memory_diff < 5.5, (
        f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
        " grad_output should be quantized only columnwise."
    )


1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights

    The module is built under ``torch.no_grad()`` and ``fp8_model_init`` and
    then run forward under ``torch.inference_mode()``. Both before and after
    that pass, quantized weights must hold only inference data:
    - a Float8Tensor has ``_data`` but no valid transpose;
    - an MXFP8Tensor has row-wise data but no column-wise data.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    # Tensor dimensions
    seq_len = 32
    width = 32

    # Skip configurations the current device cannot run
    needs_fp8 = quantization in ("fp8_delayed_scaling", "fp8_current_scaling")
    if needs_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Map the quantization name to a recipe; None stays high precision.
    recipe_builders = {
        "fp8_delayed_scaling": lambda: recipe.DelayedScaling(),
        "fp8_current_scaling": lambda: recipe.Float8CurrentScaling(),
        "mxfp8": lambda: recipe.MXFP8BlockScaling(),
    }
    with_quantization = quantization not in (None, "None")
    build_recipe = recipe_builders.get(quantization)
    quantization_recipe = build_recipe() if build_recipe is not None else None

    # Map the module name to a constructor; built inside the init contexts below.
    module_builders = {
        "Linear": lambda: Linear(width, width),
        "LayerNormLinear": lambda: LayerNormLinear(width, width),
        "LayerNormMLP": lambda: LayerNormMLP(width, width),
        "GroupedLinear": lambda: GroupedLinear(1, width, width),
        "ops.Linear": lambda: transformer_engine.pytorch.ops.Linear(width, width),
    }
    module = None
    with torch.no_grad(), fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
        build_module = module_builders.get(module_name)
        if build_module is not None:
            module = build_module()

    def check_weights():
        """Assert that every quantized parameter holds only inference-time data."""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Right after initialization ...
    check_weights()

    # ... and again after an inference-mode forward pass.
    with torch.inference_mode():
        x = torch.zeros(seq_len, width, device="cuda")
        forward_kwargs = {"m_splits": [seq_len]} if module_name == "GroupedLinear" else {}
        with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
            _ = module(x, **forward_kwargs)
    check_weights()