# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
from typing import Optional

import pytest
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

import transformer_engine.pytorch
from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    fp8_model_init,
)
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
    is_bf16_compatible,
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    GroupedLinear,
    LayerNormMLP,
    TransformerLayer,
    RMSNorm,
    LayerNorm,
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
    Float8Tensor,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from utils import ModelConfig

# Only run FP8 tests on supported devices: each FP8 flavor is queried once, at import.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Record initial RNG state from script run (CPU generator and the current CUDA device).
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Non-zero value switches the whole suite into nvdlfw_inspect debug mode.
NVTE_TEST_NVINSPECT_ENABLED = int(os.getenv("NVTE_TEST_NVINSPECT_ENABLED", "0"))


if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests must behave the same when debug=True. A dummy feature is fed in
    # (presumably through the config file below) so that debug mode is not switched
    # off, which can happen when no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )

def is_fp8_supported(config: ModelConfig):
    """Return True when every shape in ``config`` is a multiple of 16.

    Checked, in order and with short-circuiting: the query token count
    (``max_seqlen_q * batch_size``), the key/value token count
    (``max_seqlen_kv * batch_size``), ``hidden_size`` and ``hidden_size_kv``.
    FP8 variants of the tests are skipped when this returns False.
    """
    return (
        (config.max_seqlen_q * config.batch_size) % 16 == 0
        and (config.max_seqlen_kv * config.batch_size) % 16 == 0
        and config.hidden_size % 16 == 0
        and config.hidden_size_kv % 16 == 0
    )

# Model shapes exercised by the tests. The positional arguments are presumably
# (batch_size, max_seqlen, num_heads, head_dim) -- confirm against utils.ModelConfig.
# "weird" deliberately uses odd sizes (3 * 37 tokens is not a multiple of 16).
model_configs = {
    "126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
    "small": ModelConfig(2, 32, 2, 32, num_layers=2),
    "weird": ModelConfig(3, 37, 3, 23, num_layers=2),
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}

# Quantization recipes usable on this device; the trailing ``None`` means "no FP8".
fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes += [recipe.Float8CurrentScaling(), recipe.DelayedScaling()]
fp8_recipes.append(None)

# bf16 requires sm_80 or higher.
param_types = [torch.float32, torch.float16] + ([torch.bfloat16] if is_bf16_compatible() else [])

all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]

all_activations = [
    "gelu",
    "geglu",
    "qgelu",
    "qgeglu",
    "relu",
    "reglu",
    "srelu",
    "sreglu",
    "silu",
    "swiglu",
]
all_normalizations = ["LayerNorm", "RMSNorm"]

122

schetlur-nv's avatar
schetlur-nv committed
123
124
125
126
127
def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that calls ``FP8GlobalStateManager.reset()`` after every test.

    Tests in this module enter ``fp8_autocast`` / ``fp8_model_init`` with different
    recipes; resetting afterwards keeps global FP8 state from carrying over between tests.
    """
    yield  # nothing to set up -- let the test run, then clean up
    FP8GlobalStateManager.reset()


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run ``block`` under ``torch.autocast`` (optionally inside FP8) and check dtypes.

    The input is FP32 with shape ``(max_seqlen_q, batch_size, hidden_size)``. The random
    boolean attention mask has shape ``(1, 1, max_seqlen_q, max_seqlen_kv)``. After one
    forward/backward pass:

    * the output has the autocast ``dtype``;
    * the input gradient exists and is FP32;
    * every trainable parameter's gradient is FP32.
    """
    hidden_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    hidden_states = torch.randn(
        hidden_shape, dtype=torch.float32, device="cuda", requires_grad=True
    )
    hidden_states.retain_grad()
    mask_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
    attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
            out = block(hidden_states, attention_mask=attn_mask)
        loss = out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert out.dtype == dtype, "AMP wrong output type."
    assert hidden_states.grad is not None, "Gradient should not be empty"
    assert hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    trainable = ((name, p) for name, p in block.named_parameters() if p.requires_grad)
    for name, p in trainable:
        assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    """Check that weight gradients are accumulated into ``param.main_grad``.

    Which parameters are tracked:

    * the name contains "weight" but not "layer_norm_weight";
    * the parameter still requires grad.

    Each tracked parameter gets a zero ``main_grad`` buffer before one forward/backward
    pass. Afterwards every such buffer must contain at least one non-zero entry.
    """

    def _tracked(name, param):
        # Same selection is used when creating the buffers and when checking them.
        return "layer_norm_weight" not in name and "weight" in name and param.requires_grad

    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    for name, param in block.named_parameters():
        if _tracked(name, param):
            param.main_grad = torch.zeros_like(param)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        out = block(hidden_states, attention_mask=attn_mask)
    out.sum().backward()
    torch.cuda.synchronize()

    failed_grads = [
        name
        for name, param in block.named_parameters()
        if _tracked(name, param) and not torch.count_nonzero(param.main_grad) > 0
    ]
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."

def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test one forward/backward pass of ``block`` without an attention mask.

    The input has shape ``(max_seqlen_q, batch_size, hidden_size)`` and dtype ``dtype``.
    FP8 is enabled when ``fp8_recipe`` is not None. Weight gradients are disabled first
    when ``skip_wgrad`` is set.
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        output = block(hidden_states)
    output.sum().backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test ``block`` with a per-sample attention mask (BERT-style test).

    The random boolean mask has shape ``(batch_size, 1, 1, max_seqlen_q)``. One
    forward/backward pass is run, optionally under FP8.
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    mask_shape = (config.batch_size, 1, 1, config.max_seqlen_q)
    attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        output = block(hidden_states, attention_mask=attn_mask)
    output.sum().backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    """Smoke-test a decoder-style ``block`` with self- and cross-attention masks.

    Inputs passed to ``block``:

    * the input tensor, reused as ``encoder_output``;
    * a self-attention mask of shape ``(1, 1, max_seqlen_q, max_seqlen_kv)``;
    * an encoder-decoder mask of shape ``(batch_size, 1, 1, max_seqlen_kv)``.

    One forward/backward pass is run, optionally under FP8.
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    self_mask_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
    attn_mask = torch.randint(2, self_mask_shape, dtype=torch.bool, device="cuda")
    cross_mask_shape = (config.batch_size, 1, 1, config.max_seqlen_kv)
    enc_dec_attn_mask = torch.randint(2, cross_mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        output = block(
            hidden_states,
            attention_mask=attn_mask,
            encoder_output=hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    output.sum().backward()
    torch.cuda.synchronize()


def _test_sanity_common(
    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
    """Forward/backward smoke test shared by the single-module tests.

    Args:
        block: module under test; called as ``block(inp)`` or, with ``microbatching``,
            twice with ``is_first_microbatch=True`` and then ``False``.
        dtype: dtype of the random input.
        config: model config providing the input shape.
        fp8_recipe: FP8 recipe, or None to run without FP8.
        skip_wgrad: freeze ``block``'s parameters before running.
        skip_dgrad: create the input with ``requires_grad=False``.
        microbatching: exercise the two-call ``is_first_microbatch`` path.

    The test is skipped when both gradients are disabled, since backward would fail.
    A tuple output is reduced to its first element before the loss is formed.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        if microbatching:
            # Only the second call's output feeds the loss.
            _ = block(inp, is_first_microbatch=True)
            out = block(inp, is_first_microbatch=False)
        else:
            out = block(inp)

    if isinstance(out, tuple):
        out = out[0]
    out.sum().backward()
    torch.cuda.synchronize()


def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    """Run a normalization module under ``torch.autocast`` and check AMP dtypes.

    Args:
        block: normalization module (``LayerNorm`` / ``RMSNorm``) with FP32 parameters.
        dtype: autocast dtype; the forward output must come out in this dtype.
        config: model config providing ``max_seqlen_q``, ``batch_size`` and
            ``hidden_size`` for the FP32 input.
        skip_wgrad: freeze ``block``'s parameters so no weight gradients are produced.
        skip_dgrad: create the input without ``requires_grad`` so no input gradient is
            produced.

    Checked after one backward pass:

    * the output has the autocast ``dtype``;
    * the input gradient is FP32 (only when it was requested);
    * every trainable parameter's gradient is FP32.

    The test is skipped when both gradients are disabled.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    # Previously the input always required grad and weights were never frozen, so the
    # skip_wgrad / skip_dgrad parametrization did not change what was exercised.
    te_inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=not skip_dgrad,
    )
    # No retain_grad(): te_inp is a leaf, so its .grad is kept whenever it requires grad,
    # and retain_grad() would raise when it does not.

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    if not skip_dgrad:
        assert te_inp.grad is not None, "Gradient should not be empty"
        assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    """AMP sanity check for standalone ``LayerNorm`` / ``RMSNorm`` with FP32 parameters."""
    config = model_configs[model]
    norm_cls = {"RMSNorm": RMSNorm}.get(normalization, LayerNorm)
    block = norm_cls(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
    microbatching,
):
    """Forward/backward sanity check for ``LayerNormLinear`` (hidden -> 3 * hidden).

    FP8 variants are skipped when the model shapes are not FP8-aligned.
    """
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    block = LayerNormLinear(
        config.hidden_size,
        3 * config.hidden_size,
        init_method=init_method_normal(0.023),
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    """Forward/backward sanity check for a square ``Linear`` (hidden -> hidden).

    The weight uses the scaled (output-layer) normal initializer. FP8 variants are
    skipped when the model shapes are not FP8-aligned.
    """
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=scaled_init_method_normal(0.023, config.num_layers),
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    """``Linear`` (hidden -> 4 * hidden) on ``bs * max_seqlen_q`` tokens, including zero.

    With ``fp8_model_params``, the module is built under ``fp8_model_init``. That case is
    not supported in nvinspect debug mode and is skipped there.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    config = model_configs[model]
    out_features = 4 * config.hidden_size
    num_tokens = bs * config.max_seqlen_q

    use_fp8 = fp8_recipe is not None
    if use_fp8 and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_linear = Linear(
            config.hidden_size, out_features, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp = torch.randn(num_tokens, config.hidden_size, dtype=dtype, requires_grad=True).cuda()
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_linear(inp)
    out.sum().backward()
    assert out.shape == (num_tokens, out_features)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
    """``GroupedLinear`` where one of the ``num_gemms`` splits is empty.

    Every split initially holds ``16 * bs * max_seqlen_q`` tokens. The split chosen by
    ``empty_split`` ("first", "last" or "middle") is then set to zero. The input
    therefore has ``num_gemms - 1`` splits' worth of tokens.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    tokens_per_split = bs * config.max_seqlen_q
    num_tokens = tokens_per_split * (num_gemms - 1)

    use_fp8 = fp8_recipe is not None
    if use_fp8 and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp = torch.randn(num_tokens, config.hidden_size, dtype=dtype, requires_grad=True).cuda()

    m_splits = [tokens_per_split] * num_gemms
    empty_index = {"first": 0, "last": -1, "middle": num_gemms // 2}.get(empty_split)
    if empty_index is not None:
        m_splits[empty_index] = 0

    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_grouped_linear(inp, m_splits)
    out.sum().backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
    microbatching,
):
    """Forward/backward sanity check for ``LayerNormMLP`` (hidden -> 4 * hidden -> hidden).

    Runs over every activation and normalization type. FP8 variants are skipped when
    the model shapes are not FP8-aligned.
    """
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
):
    """End-to-end sanity check for a GPT-style ``TransformerLayer``.

    The layer uses pre-layernorm (no residual-after-layernorm, no output layernorm) and
    is run without an attention mask.
    """
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_gpt_126m():
    """Run ``test_sanity_gpt`` once on the "126m" config.

    Uses the last available param dtype. A delayed-scaling E4M3 recipe is used when FP8
    is available; otherwise FP8 is off.
    """
    fp8_recipe = (
        recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
        if fp8_available
        else None
    )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """End-to-end sanity check for a BERT-style ``TransformerLayer``.

    The layer uses a post-layernorm residual, an output layernorm and the "causal"
    self-attention mask type. It is driven with a per-sample boolean attention mask.
    FP8 variants are skipped when FP8 is unavailable or the shapes are not FP8-aligned.
    """
    config = model_configs[model]

    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    """Run ``test_sanity_bert`` once on the "126m" config.

    Uses a delayed-scaling E4M3 recipe with a one-step amax history. The callee skips
    the test when FP8 is unavailable.
    """
    delayed_scaling = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=delayed_scaling,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """End-to-end sanity check for a T5-style decoder ``TransformerLayer``.

    The layer (``layer_type="decoder"``) receives both a self-attention mask and an
    encoder-decoder mask. FP8 variants are skipped when FP8 is unavailable or the
    shapes are not FP8-aligned.
    """
    config = model_configs[model]

    if fp8_recipe is not None and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    """Run ``test_sanity_T5`` once on the "126m" config.

    Uses a delayed-scaling E4M3 recipe with a one-step amax history. The callee skips
    the test when FP8 is unavailable.
    """
    delayed_scaling = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=delayed_scaling,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    """AMP end-to-end check for a ``TransformerLayer`` with FP32 parameters.

    ``dtype`` is the autocast dtype, not the parameter dtype. No nvFuser-specific setup
    is visible here; the name is presumably historical.
    """
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
def test_sanity_drop_path(dtype, fp8_recipe, model):
    """End-to-end sanity check for a ``TransformerLayer`` with ``drop_path_rate=1.0``.

    Weight gradients are always kept.
    """
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad=False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    """End-to-end sanity check for a ``TransformerLayer`` built with ``fuse_qkv_params=True``."""
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity run of a TransformerLayer with fused QKV params and fused wgrad accumulation."""
    cfg = model_configs[model]

    # FP8 variants only run on model configs accepted by is_fp8_supported.
    if fp8_recipe is not None and not is_fp8_supported(cfg):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    layer_kwargs = {
        "init_method": init_method_normal(std),
        "output_layer_init_method": scaled_init_method_normal(std, cfg.num_layers),
        "hidden_dropout": 0.1,
        "attention_dropout": 0.1,
        "kv_channels": cfg.kv_channels,
        "params_dtype": dtype,
        "apply_residual_connection_post_layernorm": False,
        "output_layernorm": False,
        "fuse_qkv_params": True,
        "fuse_wgrad_accumulation": True,
        "device": "cuda",
    }
    layer = TransformerLayer(cfg.hidden_size, 4 * cfg.hidden_size, cfg.num_heads, **layer_kwargs)

    _test_sanity_e2e_gradient_accumulation_fusion(layer, dtype, cfg, fp8_recipe, skip_wgrad)
859
860


861
def test_model_multiple_cast():
    """Check that a Linear module's output dtype follows the module after ``.half()``."""
    inp = torch.zeros((16, 16), device="cuda")
    layer = Linear(16, 32)

    # A freshly built module fed a float32 input yields a float32 output.
    assert layer(inp).dtype == torch.float32

    # After casting both the module and the input to half precision,
    # the output must be float16 as well.
    layer.half()
    assert layer(inp.half()).dtype == torch.float16
873
874
875
876
877
878


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
    """Run general_gemm on N x N operands that are views at unaligned buffer offsets."""
    numel = N * N
    buf = torch.randn(numel + 2 * offset, device="cuda", dtype=datatype)

    # Two overlapping N x N views into one buffer, starting `offset` and
    # `2 * offset` elements in, so neither starts at the buffer's base address.
    inp = buf[offset : offset + numel].reshape(N, N)
    weight = buf[2 * offset :].reshape(N, N)

    general_gemm(A=weight, B=inp, workspace=get_workspace())
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    """Run an FP8 general_gemm on operands sliced at unaligned offsets of a quantized row."""
    offset = 16
    scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)

    # E4M3 quantizer with unit scale and unit amax.
    quantizer = Float8Quantizer(
        torch.ones(1).cuda().squeeze(),
        torch.ones(1).cuda().squeeze(),
        tex.DType.kFloat8E4M3,
    )

    # Take the first row (length N*N + offset) of the quantized tensor and
    # carve two overlapping N x N operands out of it, `offset` elements apart.
    row_fp8 = quantizer(scratchpad)[0]
    inp_fp8 = torch.reshape(row_fp8[:-offset], (N, N))
    weight_fp8 = torch.reshape(row_fp8[offset:], (N, N))

    general_gemm(
        weight_fp8,
        inp_fp8,
        get_workspace(),
        datatype,
        bias=None,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()
913
914


915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """Check that replace_raw_data swaps a Float8Tensor's ``_data`` buffer in place.

    The new buffer must hold the same bytes as the old one, and every other
    tracked attribute must remain the identical object.
    """
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    tensor = quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
    source = torch.randn(tensor.shape, dtype=torch.bfloat16, device="cuda")
    quantizer.update_quantized(source, tensor)

    # Remember the exact objects held by the attributes that must survive the swap.
    watched = ("_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid")
    snapshot = {name: getattr(tensor, name) for name in watched}

    old_data = tensor._data
    new_data = torch.empty_like(old_data)
    replace_raw_data(tensor, new_data)

    # The tensor now points at `new_data`, not at the old buffer ...
    assert tensor._data.data_ptr() != old_data.data_ptr()
    assert tensor._data.data_ptr() == new_data.data_ptr()
    # ... yet the stored values are bit-identical.
    torch.testing.assert_close(old_data, tensor._data, atol=0, rtol=0)
    # Every other watched attribute is still the very same object.
    for name in watched:
        assert getattr(tensor, name) is snapshot[name]


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_init_high_precision_init_val():
    """Test fp8_model_init with preserve_high_precision_init_val=True.

    Checks that:
    - the quantized weight exposes its high-precision initial value on the CPU;
    - re-quantizing that value reproduces the weight exactly;
    - ``clear_high_precision_init_val`` removes the stored value.
    """
    with fp8_model_init(preserve_high_precision_init_val=True):
        model = Linear(768, 768)

    weight = model.weight

    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
    assert hasattr(
        weight, "clear_high_precision_init_val"
    ), "clear_high_precision_init_val() not found"

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

    # Quantizing the preserved value with the weight's own quantizer must
    # reproduce the weight bit-for-bit.
    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape, dtype=weight.dtype, device=weight.device
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)

    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )

    weight.clear_high_precision_init_val()
    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work"
    # Use the exact attribute name. A name with a leading dot can never exist,
    # so hasattr would be trivially False and the check would be meaningless.
    assert not hasattr(
        weight, "_high_precision_init_val"
    ), "clear_high_precision_init_val() not work"
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules.

    A plain ``torch.autograd.Function`` (identity) is used through its
    ``apply`` callable. The input gradient from a checkpointed pass is compared
    against the gradient from a regular pass.
    """

    # torch.autograd.Function implementing the identity.
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    assert inp.grad is not None, "Checkpointed backward produced no input gradient"
    # Snapshot the gradient and reset it. Otherwise the next backward pass
    # accumulates into the same tensor object, and the comparison below would
    # compare a tensor with itself.
    grad_checkpoint = inp.grad.clone()
    inp.grad = None

    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad.clone()

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)
1008
1009


1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
    """Test that backward memory usage is optimized when weights are frozen.

    Runs a Linear layer under ``fp8_autocast()`` with its default recipe and a
    frozen weight. It then checks that the CUDA peak-memory growth caused by
    the backward pass stays below 5.5 MB.
    """
    # max_memory_allocated() reports the peak since the last reset (or process
    # start). Reset it so a larger peak left behind by earlier tests cannot
    # mask the backward allocation and make the check pass vacuously.
    torch.cuda.reset_peak_memory_stats()

    dim = 1024
    linear = Linear(dim, dim, bias=False)
    x = torch.randn(dim, dim, requires_grad=True, device="cuda")

    # Freeze weights
    linear.weight.requires_grad = False

    # Forward pass with FP8
    with fp8_autocast():
        o = linear(x)
        g_o = torch.randn_like(o)

    max_memory_before_backward = torch.cuda.max_memory_allocated()
    o.backward(g_o)
    max_memory_after_backward = torch.cuda.max_memory_allocated()

    memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
    assert memory_diff < 5.5, (
        f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
        " grad_output should be quantized only columnwise."
    )


1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights.

    Builds the module under ``torch.no_grad()`` and ``fp8_model_init``, then
    checks the quantized weight buffers. The check is repeated after a forward
    pass under ``torch.inference_mode()``. Only the data needed for inference
    may be present:
    - no FP8 transpose for Float8Tensor weights;
    - no column-wise data for MXFP8Tensor weights.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    # Tensor dimensions
    sequence_length = 32
    hidden_size = 32

    # Skip invalid configurations
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Construct quantization recipe
    with_quantization = quantization not in (None, "None")
    quantization_recipe = None
    if quantization == "fp8_delayed_scaling":
        quantization_recipe = recipe.DelayedScaling()
    elif quantization == "fp8_current_scaling":
        quantization_recipe = recipe.Float8CurrentScaling()
    elif quantization == "mxfp8":
        quantization_recipe = recipe.MXFP8BlockScaling()

    # Construct module
    with torch.no_grad():
        with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
            if module_name == "Linear":
                module = Linear(hidden_size, hidden_size)
            elif module_name == "LayerNormLinear":
                module = LayerNormLinear(hidden_size, hidden_size)
            elif module_name == "LayerNormMLP":
                module = LayerNormMLP(hidden_size, hidden_size)
            elif module_name == "GroupedLinear":
                module = GroupedLinear(1, hidden_size, hidden_size)
            elif module_name == "ops.Linear":
                module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)
            else:
                # Fail loudly instead of leaving `module` unset and crashing
                # later with an unrelated AttributeError.
                raise ValueError(f"Unsupported module_name: {module_name!r}")

    def check_weights():
        """Helper function to check that weight parameters have expected data"""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Check that modules have expected weights after initialization
    check_weights()

    # Check that modules have expected weights after forward pass
    with torch.inference_mode():
        x = torch.zeros(sequence_length, hidden_size, device="cuda")
        kwargs = {}
        if module_name == "GroupedLinear":
            # A single group covering the whole sequence.
            kwargs["m_splits"] = [sequence_length]
        with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
            # Output is not inspected; only the forward pass's side effects on
            # the weights matter here.
            module(x, **kwargs)
    check_weights()