# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from dataclasses import dataclass
from typing import Optional
from contextlib import nullcontext

import torch
import pytest
import io

from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    fp8_model_init,
)
from transformer_engine.pytorch.utils import (
    get_device_compute_capability,
    init_method_normal,
    scaled_init_method_normal,
    is_bf16_compatible,
    get_cudnn_version,
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    LayerNormMLP,
    TransformerLayer,
    RMSNorm,
    LayerNorm,
    get_cpu_offload_context,
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import (
    gemm,
    fp8_gemm,
    gelu,
    cast_to_fp8,
    cast_from_fp8,
)
from transformer_engine.pytorch.module.base import get_workspace
from test_onnx_export import create_meta

# Only run FP8 tests on FP8-capable devices (e.g. H100).
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


def custom_amax_to_scale(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: torch.Tensor,
    recipe: recipe.DelayedScaling,
) -> torch.Tensor:
    """Custom func to test recipe."""
    sf = fp8_max / amax
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)

    return sf


def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    """Custom func to test recipe."""
    return torch.min(amax_history, dim=0).values


@dataclass
class ModelConfig:
    """Transformer model configuration"""

    num_layers: int
    seq_len: int
    batch_size: int
    hidden_size: int
    num_attention_heads: int
    kv_channels: Optional[int] = None

    def is_fp8_supported(self):
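        # FP8 execution in these tests requires the flattened token count
        # (seq_len * batch_size) and the hidden size to be divisible by 16.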
        if self.seq_len * self.batch_size % 16:
            return False
        if self.hidden_size % 16:
            return False
        return True


model_configs = {
    "126m": ModelConfig(12, 2048, 2, 768, 12),
    "small": ModelConfig(2, 32, 2, 64, 2),
    "weird": ModelConfig(2, 37, 3, 69, 3),
    "large": ModelConfig(1, 128, 2, 512, 4, 128),
}

fp8_recipes = [
    None,  # Handles non-FP8 case
    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        override_linear_precision=(False, False, True),
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo="most_recent",
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo="max",
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo=custom_amax_compute,
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        scaling_factor_compute_algo=custom_amax_to_scale,
    ),
]

param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]


def _disable_wgrads(block):
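    # Freeze all parameters so no weight gradients are computed.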
    for p in block.parameters():
        p.requires_grad = False


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
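    # Run the test, then clear the global FP8 state so that recipe settings and
    # amax history do not leak into subsequent tests.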
    yield
    FP8GlobalStateManager.reset()


def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
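    # Full training step under CUDA graphs: warm up on a side stream, capture one
    # iteration, then replay the graph with fresh input data.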
    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for capture.
    static_input = torch.randn(
        config.seq_len,
        config.batch_size,
        config.hidden_size,
        device="cuda",
        dtype=dtype,
        requires_grad=True,
    )
    static_target = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
    )

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    use_fp8 = fp8_recipe is not None
    if skip_wgrad:
        _disable_wgrads(block)

    # Pre graph capture warmup in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            optimizer.zero_grad(set_to_none=True)
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
                out = block(static_input)
            loss = loss_fn(out, static_target)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture.
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
            static_output = block(static_input)
        static_loss = loss_fn(static_output, static_target)
        static_loss.backward()
        optimizer.step()

    # Fills the graph's input memory with new data to compute on
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    g.replay()

    torch.cuda.synchronize()


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
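    # Forward/backward under torch.autocast: the output should be in the autocast
    # dtype while input and weight gradients remain FP32.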
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
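    # Attach a zeroed .main_grad buffer to each weight and verify that fused wgrad
    # accumulation writes into it during backward.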
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            p.main_grad = torch.zeros_like(p)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."


def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
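    # Minimal forward/backward pass, optionally with activation CPU offloading.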
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    if cpu_offload:
        offload_context, sync_function = get_cpu_offload_context(enabled=True)
    else:
        offload_context = nullcontext()
        sync_function = lambda x: x

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
        te_out = block(te_inp_hidden_states)
    te_out = sync_function(te_out)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
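    # Encoder-style (BERT) pass with a padding mask of shape [batch, 1, 1, seq].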
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    te_inp_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
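    # Decoder-style (T5) pass: self-attention mask plus an encoder-decoder mask,
    # reusing the same hidden states as the encoder output.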
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    enc_dec_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            encoder_output=te_inp_hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
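    # Generic forward/backward smoke test shared by the per-module sanity tests.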
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
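    # Standalone LayerNorm/RMSNorm module under torch.autocast; checks output and
    # gradient dtypes.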
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=True,
    )
    te_inp.retain_grad()

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    config = model_configs[model]
    module = RMSNorm if normalization == "RMSNorm" else LayerNorm

    block = module(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = LayerNormLinear(
        config.hidden_size,
        config.hidden_size * 3,
        init_method=init_method,
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=output_layer_init_method,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
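    # Covers batch size 0 to ensure modules handle empty (zero-token) inputs.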
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.seq_len

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
    cpu_offload,
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)


def test_sanity_gpt_126m():
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=True,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
        cpu_offload=False,
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        zero_centered_gamma=zero_centered_gamma,
        self_attn_mask_type="padding",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        fuse_qkv_params=True,
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)


def test_model_multiple_cast():
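    # Module output dtype should follow the parameter dtype after casting with .half().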
    a = torch.zeros((16, 16), device="cuda")
    m = Linear(16, 32)

    y = m(a)
    assert y.dtype == torch.float32

    m.half()
    a = a.half()

    y2 = m(a)
    assert y2.dtype == torch.float16


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
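    # Build input and weight views at small element offsets into a scratch buffer
    # so their data pointers are deliberately misaligned, then run the GEMM.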
    scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
    inp = torch.reshape(scratchpad[offset:-offset], (N, N))
    weight = torch.reshape(scratchpad[offset * 2 :], (N, N))

    _, _, _ = gemm(A=weight, B=inp, dtype=datatype, workspace=get_workspace())
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
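    # FP8 variant of the unaligned-GEMM test: offset the FP8 scratch buffer so the
    # input and weight views are misaligned.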
    offset = 16
    scratchpad = torch.randn(N * N + offset, device="cuda", dtype=datatype)

    fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
    fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT

    nb_inp_scales, nb_weight_scales = 1, N
    scale_factor = 1.0
    meta_inp = create_meta(scale_factor, nb_inp_scales)
    meta_weight = create_meta(scale_factor, nb_weight_scales)
    inp_type = tex.DType.kFloat8E4M3
    weights_type = tex.DType.kFloat8E4M3
    outp_type = datatype

    scratchpad_fp8 = cast_to_fp8(scratchpad, meta_weight, fp8_tensor_inp, inp_type)
    inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N))
    weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N))
    _, _ = fp8_gemm(
        weight_fp8,
        meta_weight.scale_inv,
        fp8_tensor_weight,
        inp_type,
        inp_fp8,
        meta_inp.scale_inv,
        fp8_tensor_inp,
        weights_type,
        outp_type,
        get_workspace(),
        bias=None,
        use_bias=False,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
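    # Round-trip a checkpoint containing the attention module's _extra_state (FP8
    # metadata) and verify it is restored identically; a stray fused-attention
    # extra-state key must not break loading.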
    config = model_configs[model]
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=True,
        fp8_mha=False,
    )
    hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    with fp8_model_init(enabled=True):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            fuse_qkv_params=True,
            params_dtype=dtype,
            device="cuda",
        )
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        output = block(hidden_states, is_first_microbatch=True)
        loss = output.sum()
        loss.backward()

    # call state_dict()
    sd = block.state_dict()

    # check core_attention._extra_state
    attn_extra_state = sd["self_attention.core_attention._extra_state"]
    attn_extra_state.seek(0)
    attn_extra_state = torch.load(attn_extra_state, map_location="cuda")

    # add random core_attention.fused_attention._extra_state
    # it should not be loaded or cause any 'unexpected key' errors
    random_state = {"a": 1, "b": 2}
    fused_attn_extra_state = io.BytesIO()
    torch.save(random_state, fused_attn_extra_state)
    sd["self_attention.core_attention.fused_attention._extra_state"] = fused_attn_extra_state

    # save checkpoint
    path = "./checkpoint.pt"
    torch.save(sd, path)

    # reinit the model
    del block
    with fp8_model_init(enabled=True):
        block_new = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            fuse_qkv_params=True,
            params_dtype=dtype,
            device="cuda",
        )
    FP8GlobalStateManager.reset()

    # load from checkpoint
    block_new.load_state_dict(torch.load(path))

    # check state_dict
    sd_new = block_new.state_dict()
    attn_extra_state_new = sd_new["self_attention.core_attention._extra_state"]
    attn_extra_state_new.seek(0)
    attn_extra_state_new = torch.load(attn_extra_state_new, map_location="cuda")
    for k, v in attn_extra_state_new.items():
        if k != "extra_fp8_variables":
            assert torch.equal(v, attn_extra_state[k]), f"{k} is not equal"
        else:
            for ek, ev in attn_extra_state_new["extra_fp8_variables"].items():
                assert ev == attn_extra_state["extra_fp8_variables"][ek], f"{ek} is not equal"