test_sanity.py 36.3 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
from typing import Optional

Przemek Tredak's avatar
Przemek Tredak committed
7
8
import torch
import pytest
9
import os
Przemek Tredak's avatar
Przemek Tredak committed
10

11
import transformer_engine.pytorch
12
13
14
15
16
from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    fp8_model_init,
)
Przemek Tredak's avatar
Przemek Tredak committed
17
18
19
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
20
    is_bf16_compatible,
Przemek Tredak's avatar
Przemek Tredak committed
21
22
23
24
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
25
    GroupedLinear,
Przemek Tredak's avatar
Przemek Tredak committed
26
27
    LayerNormMLP,
    TransformerLayer,
28
29
    RMSNorm,
    LayerNorm,
Przemek Tredak's avatar
Przemek Tredak committed
30
31
)
from transformer_engine.common import recipe
32
import transformer_engine_torch as tex
33
from transformer_engine.pytorch.cpp_extensions import general_gemm
34
from transformer_engine.pytorch.module.base import get_workspace
35
36
37
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
38
39
    Float8Quantizer,
    Float8Tensor,
40
)
41
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
42
from transformer_engine.pytorch.tensor.utils import replace_raw_data
43
from transformer_engine.pytorch.distributed import checkpoint
44
from utils import ModelConfig
Przemek Tredak's avatar
Przemek Tredak committed
45

46
# FP8 tests only run on devices that support the corresponding recipe.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Seed the CPU and CUDA RNGs once, when this module is imported.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Opt-in nvdlfw_inspect debug mode: any non-zero integer value enables it.
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))


if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests must behave the same when debug=True. They are run with a
    # dummy feature enabled (presumably supplied through the config file below),
    # so that debug mode is not switched off, which can happen when no feature
    # is active.
    import nvdlfw_inspect.api as debug_api

    _nvinspect_config_file = os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"]
    _nvinspect_feature_dirs = os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"]
    debug_api.initialize(_nvinspect_config_file, feature_dirs=_nvinspect_feature_dirs)

71

72
73
74
75
76
77
78
79
80
def is_fp8_supported(config: ModelConfig):
    """Return whether ``config``'s dimensions satisfy the FP8 tests' alignment rule.

    The query and key/value token counts (sequence length times batch size) and
    both hidden sizes must all be multiples of 16.
    """
    aligned_dims = (
        config.max_seqlen_q * config.batch_size,
        config.max_seqlen_kv * config.batch_size,
        config.hidden_size,
        config.hidden_size_kv,
    )
    return all(dim % 16 == 0 for dim in aligned_dims)
Przemek Tredak's avatar
Przemek Tredak committed
81

82

Przemek Tredak's avatar
Przemek Tredak committed
83
# Model shapes exercised by the tests below, keyed by name.
model_configs = {
    "126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
    "small": ModelConfig(2, 32, 2, 32, num_layers=2),
    "weird": ModelConfig(3, 37, 3, 23, num_layers=2),
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}

# Quantization recipes to parametrize over. Each recipe is included only when the
# device supports it; the trailing ``None`` means "run without FP8".
fp8_recipes = [
    *([recipe.MXFP8BlockScaling()] if mxfp8_available else []),
    *([recipe.Float8BlockScaling()] if fp8_block_scaling_available else []),
    *([recipe.Float8CurrentScaling(), recipe.DelayedScaling()] if fp8_available else []),
    None,
]

# Parameter dtypes; bf16 requires sm_80 or higher. Order matters: callers use
# ``param_types[-1]`` as the "widest supported low-precision" choice.
param_types = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_compatible() else []
)

all_boolean = [True, False]
batch_sizes_with_zero = list(range(3))

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
schetlur-nv's avatar
schetlur-nv committed
109

110

schetlur-nv's avatar
schetlur-nv committed
111
112
113
114
115
def _disable_wgrads(block):
    """Freeze every parameter of ``block`` so backward computes no weight gradients."""
    for weight in block.parameters():
        weight.requires_grad = False


116
117
118
119
120
121
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that clears FP8 global state after every test."""
    # No setup: the test body runs while suspended at this yield.
    yield None
    # Teardown: reset so FP8 state does not leak into the next test.
    FP8GlobalStateManager.reset()


122
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run ``block`` under torch.autocast (and optionally FP8), then check dtypes.

    The input is float32 and a random boolean attention mask of shape
    (1, 1, max_seqlen_q, max_seqlen_kv) is passed. The output must have the
    autocast ``dtype``, while the input gradient and every trainable parameter's
    gradient must remain float32.
    """
    hidden_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    mask_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)

    inp = torch.randn(hidden_shape, dtype=torch.float32, device="cuda", requires_grad=True)
    inp.retain_grad()
    attn_mask = torch.randint(2, mask_shape, dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
            out = block(inp, attention_mask=attn_mask)
        # Reduction happens under torch.autocast but outside fp8_autocast.
        loss = out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert out.dtype == dtype, "AMP wrong output type."
    assert inp.grad is not None, "Gradient should not be empty"
    assert inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    trainable = ((name, p) for name, p in block.named_parameters() if p.requires_grad)
    for name, p in trainable:
        assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


157
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    """Check that backward accumulates weight gradients into ``main_grad`` buffers.

    Every trainable parameter whose name contains "weight" (but not
    "layer_norm_weight") gets a zero-filled ``main_grad``. After one
    forward/backward pass, each of those buffers must hold a non-zero entry.
    """
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    attn_mask = torch.randint(
        2,
        (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    # Selected after _disable_wgrads because the filter depends on requires_grad.
    fused_params = [
        (name, p)
        for name, p in block.named_parameters()
        if "layer_norm_weight" not in name and "weight" in name and p.requires_grad
    ]
    for _, p in fused_params:
        p.main_grad = torch.zeros_like(p)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        out = block(hidden_states, attention_mask=attn_mask)
    out.sum().backward()
    torch.cuda.synchronize()

    failed_grads = [name for name, p in fused_params if torch.count_nonzero(p.main_grad) == 0]
    assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
195

Przemek Tredak's avatar
Przemek Tredak committed
196

197
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
    """Run one forward/backward pass of ``block`` with no attention mask."""
    shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    hidden_states = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        output = block(hidden_states)
    output.sum().backward()
    torch.cuda.synchronize()


216
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    """Forward/backward with a random per-sample boolean mask of shape (batch, 1, 1, seq_q)."""
    hidden_states = torch.randn(
        config.max_seqlen_q,
        config.batch_size,
        config.hidden_size,
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    mask = torch.randint(
        0, 2, (config.batch_size, 1, 1, config.max_seqlen_q), dtype=torch.bool, device="cuda"
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        output = block(hidden_states, attention_mask=mask)
    output.sum().backward()
    torch.cuda.synchronize()


242
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    """Forward/backward through a layer with self- and encoder-decoder attention masks.

    The hidden states double as the encoder output. The self-attention mask has
    shape (1, 1, seq_q, seq_kv) and the encoder-decoder mask (batch, 1, 1, seq_kv).
    """
    seq_q, seq_kv, batch = config.max_seqlen_q, config.max_seqlen_kv, config.batch_size
    hidden_states = torch.randn(
        (seq_q, batch, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True
    )
    self_attn_mask = torch.randint(2, (1, 1, seq_q, seq_kv), dtype=torch.bool, device="cuda")
    cross_attn_mask = torch.randint(2, (batch, 1, 1, seq_kv), dtype=torch.bool, device="cuda")

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        output = block(
            hidden_states,
            attention_mask=self_attn_mask,
            encoder_output=hidden_states,
            enc_dec_attn_mask=cross_attn_mask,
        )
    output.sum().backward()
    torch.cuda.synchronize()


279
280
281
def _test_sanity_common(
    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
    """Forward/backward through a single TE module, optionally as two microbatches.

    With ``microbatching`` the block is called with ``is_first_microbatch=True``
    and then ``False``; only the second output feeds the loss. A tuple output is
    reduced to its first element. Skips when both gradients are disabled.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        config.max_seqlen_q,
        config.batch_size,
        config.hidden_size,
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    with fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        if microbatching:
            _ = block(te_inp, is_first_microbatch=True)
            te_out = block(te_inp, is_first_microbatch=False)
        else:
            te_out = block(te_inp)

    result = te_out[0] if isinstance(te_out, tuple) else te_out
    result.sum().backward()
    torch.cuda.synchronize()


309
def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    """Run a standalone normalization module under torch.autocast and check dtypes.

    Honors the gradient flags like ``_test_sanity_common`` does:
    ``skip_dgrad`` makes the input not require grad, and ``skip_wgrad`` freezes the
    block's parameters. The output must have the autocast ``dtype``. Any computed
    input/weight gradients must stay float32. Skips when both gradients are disabled.
    """
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    # float32 input: torch.autocast, not the caller, selects the compute dtype.
    # The tensor is a leaf, so its .grad is populated without retain_grad().
    te_inp = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    # Previously skip_wgrad was ignored, so every parametrization computed wgrads.
    if skip_wgrad:
        _disable_wgrads(block)

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    if not skip_dgrad:
        assert te_inp.grad is not None, "Gradient should not be empty"
        assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    """Sanity-check standalone LayerNorm/RMSNorm (float32 params) under torch.autocast."""
    config = model_configs[model]
    norm_cls = {"RMSNorm": RMSNorm}.get(normalization, LayerNorm)
    block = norm_cls(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
346
347


Przemek Tredak's avatar
Przemek Tredak committed
348
349
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    normalization,
    microbatching,
):
    """Sanity-check LayerNormLinear (hidden_size -> 3 * hidden_size) forward/backward."""
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    weight_init = init_method_normal(0.023)
    block = LayerNormLinear(
        config.hidden_size,
        3 * config.hidden_size,
        init_method=weight_init,
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(
        block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=microbatching
    )
Przemek Tredak's avatar
Przemek Tredak committed
385
386
387
388


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    """Sanity-check a square Linear layer (hidden_size -> hidden_size) forward/backward."""
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=scaled_init_method_normal(0.023, config.num_layers),
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(
        block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=microbatching
    )
Przemek Tredak's avatar
Przemek Tredak committed
411
412


413
414
415
416
417
418
419
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    """Linear must accept inputs with zero tokens (``bs == 0``) and keep output shapes.

    The input has ``bs * max_seqlen_q`` rows, and the output must be
    (num_tokens, 4 * hidden_size) for every batch size including zero.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.max_seqlen_q

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    # Allocate directly on the GPU, as the other helpers in this file do. The old
    # `torch.randn(..., requires_grad=True).cuda()` built a CPU leaf and fed the
    # module a non-leaf host-to-device copy.
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


446
447
448
449
450
451
452
453
454
455
456
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
    """GroupedLinear must handle one empty split (first, last, or middle GEMM).

    ``bs`` is scaled by 16. Each non-empty split gets ``bs * max_seqlen_q`` rows,
    so the input holds ``num_gemms - 1`` splits' worth of tokens.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
        pytest.skip("FP8 model parameters are not supported in debug mode.")

    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)

    if fp8_recipe is not None:
        if not is_fp8_supported(config):
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    # Allocate directly on the GPU (a CUDA leaf), consistent with the other helpers
    # in this file, instead of a CPU leaf followed by a host-to-device copy.
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, device="cuda", requires_grad=True
    )
    m_splits = [bs * config.max_seqlen_q] * num_gemms
    if empty_split == "first":
        m_splits[0] = 0
    elif empty_split == "last":
        m_splits[-1] = 0
    elif empty_split == "middle":
        m_splits[num_gemms // 2] = 0

    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_grouped_linear(inp_hidden_states, m_splits)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


Przemek Tredak's avatar
Przemek Tredak committed
493
494
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_mlp(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    skip_dgrad,
    activation,
    normalization,
    microbatching,
):
    """Sanity-check LayerNormMLP (FFN size 4 * hidden_size) forward/backward."""
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(
        block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=microbatching
    )
Przemek Tredak's avatar
Przemek Tredak committed
535
536
537
538


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
):
    """Sanity-check a GPT-style TransformerLayer end-to-end without an attention mask.

    Also invoked directly (with keyword arguments) by ``test_sanity_gpt_126m``.
    """
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    sigma = 0.023
    layer_kwargs = dict(
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )
    block = TransformerLayer(
        config.hidden_size, 4 * config.hidden_size, config.num_heads, **layer_kwargs
    )
    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
585
586
587
588
589
590


def test_sanity_gpt_126m():
    """GPT-style sanity test on the 126m config; uses delayed-scaling FP8 when available."""
    fp8_recipe = (
        recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
        if fp8_available
        else None
    )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
    )
Przemek Tredak's avatar
Przemek Tredak committed
606
607
608
609


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """Sanity-check a BERT-style TransformerLayer with a per-sample attention mask.

    Uses post-LayerNorm residuals, an output LayerNorm and a causal self-attention
    mask type. FP8 availability is checked here too because the function is also
    called directly with an explicit recipe.
    """
    config = model_configs[model]
    uses_fp8 = fp8_recipe is not None
    if uses_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if uses_fp8 and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        self_attn_mask_type="causal",
        normalization=normalization,
        device="cuda",
    )
    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    """BERT-style sanity test on the 126m config.

    Uses a delayed-scaling FP8 recipe when the device supports FP8 and runs in
    high precision otherwise, so the 126m config is exercised on every device.
    """
    # Build the recipe only when FP8 is available, matching test_sanity_gpt_126m.
    # An unconditional recipe made test_sanity_bert skip the whole test on
    # devices without FP8 support.
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
660
661
662
663


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
    """Sanity-check a decoder-type TransformerLayer with encoder-decoder attention.

    FP8 availability is checked here too because the function is also called
    directly with an explicit recipe.
    """
    config = model_configs[model]
    uses_fp8 = fp8_recipe is not None
    if uses_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if uses_fp8 and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        normalization=normalization,
        device="cuda",
    )
    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    """T5-style (decoder) sanity test on the 126m config.

    Uses a delayed-scaling FP8 recipe when the device supports FP8 and runs in
    high precision otherwise, so the 126m config is exercised on every device.
    """
    # Build the recipe only when FP8 is available, matching test_sanity_gpt_126m.
    # An unconditional recipe made test_sanity_T5 skip the whole test on devices
    # without FP8 support.
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=1,
            amax_compute_algo="most_recent",
        )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        normalization="LayerNorm",
    )
Przemek Tredak's avatar
Przemek Tredak committed
714
715
716
717


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a float32-parameter TransformerLayer under torch.autocast(dtype)."""
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        # Parameters stay float32; autocast provides the reduced compute dtype.
        params_dtype=torch.float32,
        device="cuda",
    )
    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)
Przemek Tredak's avatar
Przemek Tredak committed
745
746
747
748


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
def test_sanity_drop_path(dtype, fp8_recipe, model):
    """Sanity-check a TransformerLayer with ``drop_path_rate=1.0``."""
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )
    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad=False)
Przemek Tredak's avatar
Przemek Tredak committed
778
779
780
781


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a TransformerLayer built with ``fuse_qkv_params=True``."""
    config = model_configs[model]
    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method_normal(std),
        output_layer_init_method=scaled_init_method_normal(std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )
    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
812
813
814
815


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
    """Sanity-check a TransformerLayer with fused wgrad accumulation.

    The layer is built with ``fuse_qkv_params=True`` and
    ``fuse_wgrad_accumulation=True``, then driven by
    ``_test_sanity_e2e_gradient_accumulation_fusion`` (parametrized over
    ``skip_wgrad``).
    """
    config = model_configs[model]

    if fp8_recipe is not None and not is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


def test_model_multiple_cast():
    """Check that a Linear module's output dtype follows a ``.half()`` cast.

    The module first runs in float32, then is cast with ``.half()`` and run
    again on a float16 input; each output must have the matching dtype.
    """
    a = torch.zeros((16, 16), device="cuda")
    m = Linear(16, 32)

    y = m(a)
    assert y.dtype == torch.float32

    m.half()
    a = a.half()

    y2 = m(a)
    assert y2.dtype == torch.float16


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
    """Run ``general_gemm`` on operands that are offset views into one buffer.

    Both N x N operands are carved out of a single 1-D buffer at element
    offsets ``offset`` and ``2 * offset``. Presumably this exercises GEMM
    handling of data pointers that are not aligned to the allocation start.
    """
    scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
    # Slicing a contiguous 1-D tensor and reshaping yields views that share
    # storage with `scratchpad` but start at different element offsets.
    inp = torch.reshape(scratchpad[offset:-offset], (N, N))
    weight = torch.reshape(scratchpad[offset * 2 :], (N, N))

    _ = general_gemm(A=weight, B=inp, workspace=get_workspace())
    # Surface any asynchronous CUDA error inside this test.
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    """Run an FP8 ``general_gemm`` on operands that are offset views of one tensor.

    A high-precision (N, N*N + offset) tensor is quantized to FP8 E4M3. Its
    first row is then sliced at two different offsets to form the N x N
    operands.
    """
    offset = 16
    scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)

    # Unit scale and amax as 0-dim float32 tensors created directly on device.
    scales = torch.ones((), device="cuda")
    amaxes = torch.ones((), device="cuda")
    dtype = tex.DType.kFloat8E4M3
    fp8_quantizer = Float8Quantizer(scales, amaxes, dtype)

    outp_type = datatype

    scratchpad_fp8 = fp8_quantizer(scratchpad)
    # Both operands come from row 0 (N*N + offset elements); they overlap and
    # start `offset` elements apart.
    inp_fp8 = torch.reshape(scratchpad_fp8[0][:-offset], (N, N))
    weight_fp8 = torch.reshape(scratchpad_fp8[0][offset:], (N, N))
    general_gemm(
        weight_fp8,
        inp_fp8,
        get_workspace(),
        outp_type,
        bias=None,
        use_split_accumulator=False,
    )
    # Surface any asynchronous CUDA error inside this test.
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
    """Test the functionality of replace_raw_data.

    After swapping a Float8Tensor's raw ``_data`` buffer for a new one, the
    tensor must point at the new storage and hold the same values. Its other
    internal attributes must be the very same objects as before.
    """
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    fp8_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    fp8_tensor = fp8_quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
    random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
    fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)

    attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
    # Keep references to the original attribute objects for identity checks.
    attrs = {attr: getattr(fp8_tensor, attr) for attr in attrs_to_check}

    old_data = fp8_tensor._data
    new_data = torch.empty_like(old_data)
    replace_raw_data(fp8_tensor, new_data)

    # Make sure the new_data is properly assigned.
    assert fp8_tensor._data.data_ptr() != old_data.data_ptr()
    assert fp8_tensor._data.data_ptr() == new_data.data_ptr()
    # Make sure the values are not changed.
    torch.testing.assert_close(old_data, fp8_tensor._data, atol=0, rtol=0)
    # Make sure other attributes are not changed (totally identical objects).
    for attr in attrs_to_check:
        assert getattr(fp8_tensor, attr) is attrs[attr]


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_init_high_precision_init_val():
    """Test fp8_model_init with preserve_high_precision_init_val=True.

    Checks that the quantized weight:
    - exposes its preserved high-precision initial value on the CPU;
    - can be re-quantized from that value to exactly the same FP8 weight;
    - drops the value (getter returns None, attribute removed) once
      ``clear_high_precision_init_val()`` is called.
    """
    with fp8_model_init(preserve_high_precision_init_val=True):
        model = Linear(768, 768)

    weight = model.weight

    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
    assert hasattr(
        weight, "clear_high_precision_init_val"
    ), "clear_high_precision_init_val() not found"

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"

    # Re-quantize the preserved copy with the weight's own quantizer; the
    # dequantized result must match the existing weight exactly.
    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape, dtype=weight.dtype, device=weight.device
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)

    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )

    weight.clear_high_precision_init_val()
    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work"
    # Check the real attribute name: a dotted name such as
    # "._high_precision_init_val" can never exist, so that check is vacuous.
    assert not hasattr(
        weight, "_high_precision_init_val"
    ), "clear_high_precision_init_val() not work"


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules.

    Uses an identity ``torch.autograd.Function`` (called through ``.apply``)
    and compares the input gradient from a checkpointed run against a plain
    run.
    """

    # torch.autograd.Function: callable via `.apply`, not an nn.Module.
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    # Snapshot and reset the leaf gradient. Otherwise the second backward pass
    # accumulates into `inp.grad`, so the two "results" below would not be
    # independent and the comparison could not catch a wrong checkpoint grad.
    grad_checkpoint = inp.grad.clone()
    inp.grad = None

    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
    """Test that backward memory stays small when Linear weights are frozen.

    Runs the forward pass under ``fp8_autocast()`` with its default recipe
    (not a specific MXFP8 recipe). The weight has ``requires_grad=False``, so
    the test expects grad_output to be quantized only column-wise. It bounds
    the growth of peak CUDA memory during backward accordingly.
    """
    dim = 1024
    linear = Linear(dim, dim, bias=False)
    x = torch.randn(dim, dim, requires_grad=True, device="cuda")

    # Freeze weights
    linear.weight.requires_grad = False

    # Forward pass with FP8 (default recipe)
    with fp8_autocast():
        o = linear(x)
        g_o = torch.randn_like(o)

    # NOTE(review): max_memory_allocated() is the peak since process start (or
    # the last reset). If an earlier test peaked higher, this diff can be 0
    # regardless of backward's footprint. Consider
    # torch.cuda.reset_peak_memory_stats() here, then re-validate the threshold.
    max_memory_before_backward = torch.cuda.max_memory_allocated()
    o.backward(g_o)
    max_memory_after_backward = torch.cuda.max_memory_allocated()

    memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
    assert memory_diff < 5.5, (
        f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
        " grad_output should be quantized only columnwise."
    )


@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights.

    Modules are constructed under ``torch.no_grad()`` and ``fp8_model_init``.
    Quantized weights must then carry only the data needed for inference:
    - a Float8Tensor has no FP8 transpose;
    - an MXFP8Tensor has no column-wise data.
    This must hold right after construction and again after a forward pass
    under ``torch.inference_mode()``.
    """
    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
        pytest.skip("Quantized model parameters are not supported in debug mode.")

    # Tensor dimensions
    sequence_length = 32
    hidden_size = 32

    # Skip invalid configurations
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Construct quantization recipe
    with_quantization = quantization not in (None, "None")
    quantization_recipe = None
    if quantization == "fp8_delayed_scaling":
        quantization_recipe = recipe.DelayedScaling()
    elif quantization == "fp8_current_scaling":
        quantization_recipe = recipe.Float8CurrentScaling()
    elif quantization == "mxfp8":
        quantization_recipe = recipe.MXFP8BlockScaling()

    # Construct module
    module = None
    with torch.no_grad():
        with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
            if module_name == "Linear":
                module = Linear(hidden_size, hidden_size)
            elif module_name == "LayerNormLinear":
                module = LayerNormLinear(hidden_size, hidden_size)
            elif module_name == "LayerNormMLP":
                module = LayerNormMLP(hidden_size, hidden_size)
            elif module_name == "GroupedLinear":
                module = GroupedLinear(1, hidden_size, hidden_size)
            elif module_name == "ops.Linear":
                module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)
            else:
                # Fail loudly instead of later on `None.parameters()`.
                raise ValueError(f"Unsupported module_name: {module_name}")

    def check_weights():
        """Helper function to check that weight parameters have expected data"""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Check that modules have expected weights after initialization
    check_weights()

    # Check that modules have expected weights after forward pass
    with torch.inference_mode():
        x = torch.zeros(sequence_length, hidden_size, device="cuda")
        kwargs = {}
        if module_name == "GroupedLinear":
            # A single group covering the whole sequence.
            kwargs["m_splits"] = [sequence_length]
        with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
            y = module(x, **kwargs)
    check_weights()