# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from dataclasses import dataclass
from typing import Optional
from contextlib import nullcontext

import torch
import pytest
import io

from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    fp8_model_init,
)
from transformer_engine.pytorch.utils import (
    get_device_compute_capability,
    init_method_normal,
    scaled_init_method_normal,
    is_bf16_compatible,
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    LayerNormMLP,
    TransformerLayer,
    RMSNorm,
    LayerNorm,
    get_cpu_offload_context,
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import (
    gemm,
    fp8_gemm,
    gelu,
    cast_to_fp8,
    cast_from_fp8,
)
from transformer_engine.pytorch.module.base import get_workspace
from test_onnx_export import create_meta

# Only run FP8 tests on FP8-capable devices (e.g. H100).
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


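# Custom amax/scale helpers exercised by the recipes below: DelayedScaling
# accepts user callables for both the amax reduction over the history window
# and the amax-to-scaling-factor computation.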
def custom_amax_to_scale(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: torch.Tensor,
    recipe: recipe.DelayedScaling,
) -> torch.Tensor:
    """Custom func to test recipe."""
    sf = fp8_max / amax
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)

    return sf


def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    """Custom func to test recipe."""
    return torch.min(amax_history, dim=0).values


@dataclass
class ModelConfig:
    """Transformer model configuration"""

    num_layers: int
    seq_len: int
    batch_size: int
    hidden_size: int
    num_attention_heads: int
    kv_channels: Optional[int] = None

    def is_fp8_supported(self):
        if self.seq_len * self.batch_size % 16:
            return False
        if self.hidden_size % 16:
            return False
        return True


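# Configs are (num_layers, seq_len, batch_size, hidden_size,
# num_attention_heads[, kv_channels]). "weird" deliberately uses dimensions
# that are not multiples of 16, so is_fp8_supported() rejects it and the FP8
# parametrizations are skipped for that config.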
model_configs = {
    "126m": ModelConfig(12, 2048, 2, 768, 12),
    "small": ModelConfig(2, 32, 2, 64, 2),
    "weird": ModelConfig(2, 37, 3, 69, 3),
    "large": ModelConfig(1, 128, 2, 512, 4, 128),
}

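# Recipes swept by the FP8 tests: both formats, a recipe that overrides the
# wgrad GEMM back to higher precision, longer amax histories with both
# built-in reduction algorithms, and the custom callables defined above.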
fp8_recipes = [
    None,  # Handles non-FP8 case
    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        override_linear_precision=(False, False, True),
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo="most_recent",
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo="max",
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo=custom_amax_compute,
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        scaling_factor_compute_algo=custom_amax_to_scale,
    ),
]

param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]


def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


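# Reset the global FP8 state after every test so that recipe and amax state
# does not leak across parametrizations.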
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
    FP8GlobalStateManager.reset()


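# Standard CUDA Graphs workflow: warm up on a side stream, capture a full
# fwd/bwd/optimizer iteration on static tensors, then copy fresh data into the
# static input and replay. fp8_autocast is entered with the internal
# _graph=True flag, presumably to keep the FP8 amax bookkeeping graph-safe.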
def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for capture.
    static_input = torch.randn(
        config.seq_len,
        config.batch_size,
        config.hidden_size,
        device="cuda",
        dtype=dtype,
        requires_grad=True,
    )
    static_target = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
    )

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    use_fp8 = fp8_recipe is not None
    if skip_wgrad:
        _disable_wgrads(block)

    # Pre graph capture warmup in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            optimizer.zero_grad(set_to_none=True)
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
                out = block(static_input)
            loss = loss_fn(out, static_target)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture.
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
            static_output = block(static_input)
        static_loss = loss_fn(static_output, static_target)
        static_loss.backward()
        optimizer.step()

    # Fills the graph's input memory with new data to compute on
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    g.replay()

    torch.cuda.synchronize()


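# AMP sanity: parameters and inputs stay FP32, torch.autocast downcasts the
# compute, and the asserts below check that the output comes back in the
# autocast dtype while dgrad and wgrads are produced in FP32.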
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


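# With fuse_wgrad_accumulation=True, each weight (layer norm weights aside) is
# expected to carry a .main_grad buffer that wgrads are accumulated into
# directly, the convention used by Megatron-style training loops.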
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            p.main_grad = torch.zeros_like(p)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."


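# Generic end-to-end fwd/bwd. With cpu_offload, activations are managed by
# get_cpu_offload_context() and its sync function must be applied to the
# output before backward; otherwise a nullcontext and identity are used.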
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    if cpu_offload:
        offload_context, sync_function = get_cpu_offload_context(enabled=True)
    else:
        offload_context = nullcontext()
        sync_function = lambda x: x

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
        te_out = block(te_inp_hidden_states)
    te_out = sync_function(te_out)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    te_inp_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    enc_dec_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            encoder_output=te_inp_hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=True,
    )
    te_inp.retain_grad()

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    config = model_configs[model]
    module = RMSNorm if normalization == "RMSNorm" else LayerNorm

    block = module(config.hidden_size).to(dtype=torch.float32).cuda()
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = LayerNormLinear(
        config.hidden_size,
        config.hidden_size * 3,
        init_method=init_method,
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=output_layer_init_method,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


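# bs=0 yields an empty (zero-token) input; the layer should still run forward
# and backward without error and return an output of the expected shape.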
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.seq_len

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params):
        te_linear = Linear(
            config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(
    dtype,
    fp8_recipe,
    model,
    skip_wgrad,
    zero_centered_gamma,
    bias,
    activation,
    normalization,
    parallel_attention_mlp,
    cpu_offload,
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)


def test_sanity_gpt_126m():
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=True,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
        cpu_offload=False,
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        zero_centered_gamma=zero_centered_gamma,
        self_attn_mask_type="padding",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(
    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        fuse_qkv_params=True,
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)


def test_model_multiple_cast():
    a = torch.zeros((16, 16), device="cuda")
    m = Linear(16, 32)

    y = m(a)
    assert y.dtype == torch.float32

    m.half()
    a = a.half()

    y2 = m(a)
    assert y2.dtype == torch.float16


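# Slice a scratch buffer at small odd offsets so the GEMM operands' data
# pointers are deliberately misaligned; the call only has to complete without
# raising, there is no numerical check here.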
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
    scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
    inp = torch.reshape(scratchpad[offset:-offset], (N, N))
    weight = torch.reshape(scratchpad[offset * 2 :], (N, N))

    _, _, _ = gemm(A=weight, B=inp, dtype=datatype, workspace=get_workspace())
    torch.cuda.synchronize()


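# Same idea in FP8: offset the operands within a shared buffer to exercise
# fp8_gemm with non-default pointer alignment; again only successful
# execution is checked.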
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    offset = 16
    scratchpad = torch.randn(N * N + offset, device="cuda", dtype=datatype)

    fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
    fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT

    nb_inp_scales, nb_weight_scales = 1, N
    scale_factor = 1.0
    meta_inp = create_meta(scale_factor, nb_inp_scales)
    meta_weight = create_meta(scale_factor, nb_weight_scales)
    inp_type = tex.DType.kFloat8E4M3
    weights_type = tex.DType.kFloat8E4M3
    outp_type = datatype

    scratchpad_fp8 = cast_to_fp8(scratchpad, meta_weight, fp8_tensor_inp, inp_type)
    inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N))
    weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N))
    _, _ = fp8_gemm(
        weight_fp8,
        meta_weight.scale_inv,
        fp8_tensor_weight,
        inp_type,
        inp_fp8,
        meta_inp.scale_inv,
        fp8_tensor_inp,
        weights_type,
        outp_type,
        get_workspace(),
        bias=None,
        use_bias=False,
        use_split_accumulator=False,
    )
    torch.cuda.synchronize()


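# Round-trip a checkpoint and verify that core_attention._extra_state (FP8
# attention metadata) survives save/load, and that an extra
# fused_attention._extra_state key is tolerated instead of raising an
# 'unexpected key' error.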
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
    config = model_configs[model]
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=True,
        fp8_mha=False,
    )
    hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    with fp8_model_init(enabled=True):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            fuse_qkv_params=True,
            params_dtype=dtype,
            device="cuda",
        )
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        output = block(hidden_states, is_first_microbatch=True)
        loss = output.sum()
        loss.backward()

    # call state_dict()
    sd = block.state_dict()

    # check core_attention._extra_state
    attn_extra_state = sd["self_attention.core_attention._extra_state"]
    attn_extra_state.seek(0)
    attn_extra_state = torch.load(attn_extra_state, map_location="cuda")

    # add random core_attention.fused_attention._extra_state
    # it should not be loaded or cause any 'unexpected key' errors
    random_state = {"a": 1, "b": 2}
    fused_attn_extra_state = io.BytesIO()
    torch.save(random_state, fused_attn_extra_state)
    sd["self_attention.core_attention.fused_attention._extra_state"] = fused_attn_extra_state

    # save checkpoint
    path = "./checkpoint.pt"
    torch.save(sd, path)

    # reinit the model
    del block
    with fp8_model_init(enabled=True):
        block_new = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            fuse_qkv_params=True,
            params_dtype=dtype,
            device="cuda",
        )
    FP8GlobalStateManager.reset()

    # load from checkpoint
    block_new.load_state_dict(torch.load(path))

    # check state_dict
    sd_new = block_new.state_dict()
    attn_extra_state_new = sd_new["self_attention.core_attention._extra_state"]
    attn_extra_state_new.seek(0)
    attn_extra_state_new = torch.load(attn_extra_state_new, map_location="cuda")
    for k, v in attn_extra_state_new.items():
        if k != "extra_fp8_variables":
            assert torch.equal(v, attn_extra_state[k]), f"{k} is not equal"
        else:
            for ek, ev in attn_extra_state_new["extra_fp8_variables"].items():
                assert ev == attn_extra_state["extra_fp8_variables"][ek], f"{ek} is not equal"