"git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "0370afa2e55c4934e8d38fa1dfe22e7a8e64345a"
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from dataclasses import dataclass
from typing import Optional
from contextlib import nullcontext

import torch
import pytest

from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    fp8_model_init,
)
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
    is_bf16_compatible,
)
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    LayerNormMLP,
    TransformerLayer,
    RMSNorm,
    LayerNorm,
    get_cpu_offload_context,
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.module.base import get_workspace
from test_onnx_export import create_meta

# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


def custom_amax_to_scale(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: torch.Tensor,
    recipe: recipe.DelayedScaling,
) -> torch.Tensor:
    """Custom func to test recipe."""
    sf = fp8_max / amax
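    # Keep the previous scale wherever amax is zero or non-finite.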
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)

    return sf


def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    """Custom func to test recipe."""
    return torch.min(amax_history, dim=0).values


@dataclass
class ModelConfig:
    """Transformer model configuration"""

    num_layers: int
    seq_len: int
    batch_size: int
    hidden_size: int
    num_attention_heads: int
    kv_channels: Optional[int] = None

    def is_fp8_supported(self):
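        # FP8 GEMMs require the flattened token count (seq_len * batch_size)
        # and the hidden size to be divisible by 16.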
        if self.seq_len * self.batch_size % 16:
            return False
        if self.hidden_size % 16:
            return False
        return True

model_configs = {
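    # "weird" uses sizes that are not multiples of 16, so FP8 tests are skipped for it.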
    "126m": ModelConfig(12, 2048, 2, 768, 12),
    "small": ModelConfig(2, 32, 2, 64, 2),
    "weird": ModelConfig(2, 37, 3, 69, 3),
}

fp8_recipes = [
    None,  # Handles non-FP8 case
    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        override_linear_precision=(False, False, True),
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo="most_recent",
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo="max",
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo=custom_amax_compute,
    ),
    recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=16,
        scaling_factor_compute_algo=custom_amax_to_scale,
    ),
]

param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]


def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
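    """Ensure that FP8 global state does not leak between tests."""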
    yield
    FP8GlobalStateManager.reset()


def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
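    """Warm up, capture a full training step in a CUDA graph, then replay it on new data."""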
    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for capture.
    static_input = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size,
        device="cuda", dtype=dtype, requires_grad=True,
    )
    static_target = torch.randn(
        config.seq_len, config.batch_size, config.hidden_size,
        device="cuda", dtype=dtype,
    )

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    use_fp8 = fp8_recipe is not None
    if skip_wgrad:
        _disable_wgrads(block)

    # Pre graph capture warmup in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            optimizer.zero_grad(set_to_none=True)
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
                out = block(static_input)
            loss = loss_fn(out, static_target)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture.
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
            static_output = block(static_input)
        static_loss = loss_fn(static_output, static_target)
        static_loss.backward()
        optimizer.step()

    # Fills the graph's input memory with new data to compute on
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    g.replay()

    torch.cuda.synchronize()


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
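    """Run forward/backward under torch.autocast and check output and gradient dtypes."""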
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=torch.float32,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    # Fused wgrad accumulation writes into `main_grad`, so pre-allocate it for
    # every weight that will receive a gradient.
    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            p.main_grad = torch.zeros_like(p)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    for name, p in block.named_parameters():
        if "layer_norm_weight" in name:
            continue
        elif "weight" in name and p.requires_grad:
            assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."


def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    # CPU offloading provides a forward-pass context manager and a sync function
    # that is applied to the block output before the backward pass.
    if cpu_offload:
        offload_context, sync_function = get_cpu_offload_context(enabled=True)
    else:
        offload_context = nullcontext()
        sync_function = lambda x: x

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
        te_out = block(te_inp_hidden_states)
    te_out = sync_function(te_out)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    te_inp_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_attn_mask = torch.randint(
        2,
        (1, 1, config.seq_len, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    enc_dec_attn_mask = torch.randint(
        2,
        (config.batch_size, 1, 1, config.seq_len),
        dtype=torch.bool,
        device="cuda",
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            encoder_output=te_inp_hidden_states,
            enc_dec_attn_mask=enc_dec_attn_mask,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
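    """Shared forward/backward smoke test used by the per-module sanity tests."""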
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=not skip_dgrad,
    )

    if skip_wgrad:
        _disable_wgrads(block)

    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        te_out = block(te_inp)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()


def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        (config.seq_len, config.batch_size, config.hidden_size),
        device="cuda",
        requires_grad=True,
    )
    te_inp.retain_grad()

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()
    loss.backward()

    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normalization):
    config = model_configs[model]
    module = RMSNorm if normalization == "RMSNorm" else LayerNorm

    block = (
        module(config.hidden_size)
        .to(dtype=torch.float32)
        .cuda()
    )
    _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad,
                                 zero_centered_gamma, skip_dgrad,
                                 normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)

    block = LayerNormLinear(
        config.hidden_size,
        config.hidden_size * 3,
        init_method=init_method,
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=output_layer_init_method,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
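    # bs == 0 covers the degenerate case of an input with zero tokens.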
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.seq_len

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params):
        te_linear = (
            Linear(
                config.hidden_size,
                ffn_hidden_size,
                bias=use_bias,
                params_dtype=dtype
            )
            .cuda()
        )

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
                              zero_centered_gamma, skip_dgrad, activation,
                              normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    )
    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
                    zero_centered_gamma, bias, activation,
                    normalization, parallel_attention_mlp,
                    cpu_offload):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        bias=bias,
        activation=activation,
        normalization=normalization,
        device="cuda",
        parallel_attention_mlp=parallel_attention_mlp,
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)


def test_sanity_gpt_126m():
    fp8_recipe = None
    if fp8_available:
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.E4M3,
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
    test_sanity_gpt(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=True,
        bias=True,
        activation="gelu",
        normalization="LayerNorm",
        parallel_attention_mlp=False,
        cpu_offload=False,
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                     normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
        zero_centered_gamma=zero_centered_gamma,
        self_attn_mask_type="padding",
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_bert_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_bert(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                   normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad)


def test_sanity_T5_126m():
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.E4M3,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )
    test_sanity_T5(
        dtype=param_types[-1],
        fp8_recipe=fp8_recipe,
        model="126m",
        skip_wgrad=False,
        zero_centered_gamma=False,
        normalization="LayerNorm",
    )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=torch.float32,
        device="cuda",
    )

    _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
        device="cuda",
    )

    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        fuse_qkv_params=True,
        fuse_wgrad_accumulation=True,
        device="cuda",
    )

    _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
                        normalization):
    config = model_configs[model]

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        zero_centered_gamma=zero_centered_gamma,
        fuse_qkv_params=True,
        normalization=normalization,
        device="cuda",
    )

    _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)

def test_model_multiple_cast():
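    # The output dtype should follow the module/input dtype across a .half() cast.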
    a = torch.zeros((16, 16), device="cuda")
    m = Linear(16, 32)

    y = m(a)
    assert y.dtype == torch.float32

    m.half()
    a = a.half()

    y2 = m(a)
    assert y2.dtype == torch.float16


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
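    # Slice both operands out of a shared buffer at odd offsets so their data pointers are misaligned.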
    scratchpad = torch.randn(N*N + 2*offset, device="cuda", dtype=datatype)
    inp = torch.reshape(scratchpad[offset:-offset], (N, N))
    weight = torch.reshape(scratchpad[offset*2:], (N, N))

    _, _, _ = gemm(
        A=weight,
        B=inp,
        dtype=datatype,
        workspace=get_workspace())
    torch.cuda.synchronize()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
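    # Cast a shared scratch buffer to FP8 and slice input/weight at an offset, producing misaligned FP8 GEMM operands.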
    offset = 16
    scratchpad = torch.randn(N*N + offset, device="cuda", dtype=datatype)

    fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
    fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT

    nb_inp_scales, nb_weight_scales = 1, N
    scale_factor = 1.
    meta_inp = create_meta(scale_factor, nb_inp_scales)
    meta_weight = create_meta(scale_factor, nb_weight_scales)
    inp_type = tex.DType.kFloat8E4M3
    weights_type = tex.DType.kFloat8E4M3
    outp_type = datatype

    scratchpad_fp8 = cast_to_fp8(
            scratchpad,
            meta_weight,
            fp8_tensor_inp,
            inp_type)
    inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N))
    weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N))
    _, _ = fp8_gemm(
            weight_fp8,
            meta_weight.scale_inv,
            fp8_tensor_weight,
            inp_type,
            inp_fp8,
            meta_inp.scale_inv,
            fp8_tensor_inp,
            weights_type,
            outp_type,
            get_workspace(),
            bias=None,
            use_bias=False,
            use_split_accumulator=False)
    torch.cuda.synchronize()