test_numerics.py 73.2 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
import math
6
import os
7
from typing import Dict, List, Optional
8
import pytest
9
import copy
10
import random
11
12
13
14
15

import torch
import torch.nn as nn
from torch.nn import Parameter

16
17
18
19
20
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    fp8_autocast,
    fp8_model_init,
)
21
22
23
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
24
    attention_mask_func,
25
    is_bf16_compatible,
26
27
)
from transformer_engine.pytorch import (
28
29
30
31
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
32
    GroupedLinear,
33
34
35
36
    MultiheadAttention,
    RMSNorm,
    TransformerLayer,
    LayerNorm,
37
38
    Fp8Padding,
    Fp8Unpadding,
39
)
40
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
41
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
42
43
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
44
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
45
from transformer_engine.pytorch.utils import get_device_compute_capability
46
from transformer_engine.common import recipe
47
import transformer_engine_torch as tex
48

49
# Only run FP8 tests on supported devices.
50
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
51
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
52

53
sm_80plus = get_device_compute_capability() >= (8, 0)
54

55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Fixed seed so that every run of this test module starts from the same RNG state.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
# These snapshots are taken right after seeding and are what reset_rng_states() restores.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()


class ModelConfig:
    """Plain container for the model hyper-parameters used by the tests.

    Every constructor argument is stored unchanged on the instance under
    the attribute of the same name.
    """

    def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
        # Sizes of the model.
        self.hidden_size, self.eps = hidden_size, eps
        self.num_attention_heads, self.embed = num_attention_heads, embed
        # Depth and sequence length.
        self.num_layers, self.seq_len = num_layers, seq_len


# Named model configurations used by the training-style tests.
# Positional arguments: hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
_small_config = ModelConfig(128, 1e-5, 8, 36, 4, 128)
_126m_config = ModelConfig(768, 1e-5, 12, 64, 12, 2048)
model_configs = {"small": _small_config, "126m": _126m_config}

78
79
80
81
82
83
84
85
# Model configuration(s) for inference-oriented tests; same positional layout
# as ModelConfig, with a short sequence length.
model_configs_inference = {
    # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
}
# Parametrization values for inference tests (their consumers are not in this
# part of the file — presumably attention backend, module under test, and
# input layout; verify against the tests that use them).
backends_inference = ["FlashAttention", "UnfusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]

86
# Parameter dtypes exercised by the tests; bf16 requires sm_80 or higher.
param_types = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_compatible() else []
)

# Batch sizes and generic on/off switch used for parametrization.
batch_sizes = [1, 2]

all_boolean = [True, False]

# Activation and normalization names used for parametrization.
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]

all_normalizations = ["LayerNorm", "RMSNorm"]

# Attention mask types used for parametrization.
mask_types = ["causal", "no_mask"]

# FP8 recipes used for parametrization (instantiated once at import time).
fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),
    recipe.Float8CurrentScaling(),
]

106

107
108
109
110
def get_causal_attn_mask(sq: int, device: str = "cuda") -> torch.Tensor:
    """Build a square boolean causal attention mask.

    Args:
        sq: Sequence length; the mask has shape ``(sq, sq)``.
        device: Device on which the mask is created (anything accepted by
            ``torch.ones(device=...)``). Defaults to ``"cuda"``, which is
            what this helper previously hard-coded.

    Returns:
        A ``torch.bool`` tensor that is ``True`` strictly above the main
        diagonal and ``False`` elsewhere.
    """
    upper = torch.triu(torch.ones(sq, sq, device=device), diagonal=1)
    return upper.bool()


111
112
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.

    Args:
        dtype: One of ``torch.float32``, ``torch.float16`` or
            ``torch.bfloat16``.

    Returns:
        A new dict with ``rtol`` and ``atol`` entries. A fresh dict is
        returned on every call, so callers may modify it freely.

    Raises:
        ValueError: If ``dtype`` is not one of the supported dtypes.
    """
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    raise ValueError(f"Unsupported dtype ({dtype})")


def assert_allclose(
    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: Optional[float] = None
) -> None:
    """Ensures two lists of tensors are element-wise close.

    Args:
        l1: Tensors under test.
        l2: Reference tensors, paired with ``l1`` by position.
        atol: Absolute tolerance passed to ``torch.allclose``.
        rtol: Relative tolerance. When ``None``, ``torch.allclose``'s
            default of ``1e-5`` is used, both for the check and for the
            diagnostic computation below.

    Raises:
        AssertionError: If the lists differ in length or any pair of
            tensors is not close. When offending elements can be located,
            the message reports the position and values of the largest
            difference.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    # Resolve the relative tolerance once so the diagnostic below uses the
    # same threshold as the check (multiplying by None would raise TypeError).
    effective_rtol = 1e-5 if rtol is None else rtol
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if torch.allclose(t1, t2, atol=atol, rtol=effective_rtol):
            continue

        # Fallback message for failures without a locatable element
        # (e.g. NaNs, which compare False against every threshold).
        msg = f"Outputs not close enough in tensor at idx={i}."
        diff = torch.abs(t1 - t2)
        tol = atol + (effective_rtol * torch.abs(t2))
        exceed_mask = diff > tol
        if exceed_mask.any():
            indices = torch.nonzero(exceed_mask, as_tuple=True)
            max_diff = diff[exceed_mask].max()
            # Position of the largest difference within the masked elements.
            max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
            max_location = [idx[max_idx].item() for idx in indices]
            msg = (
                f"Outputs not close enough in tensor at idx={i}. "
                f"Maximum difference at location {max_location} "
                f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
                f"(diff {max_diff.item()})."
            )
        raise AssertionError(msg)
152
153
154


def reset_rng_states() -> None:
    """Restore the CPU and CUDA RNG states recorded at module level.

    Reads the module globals ``_cpu_rng_state`` and ``_cuda_rng_state``.
    """
    # CPU generator first, then the CUDA generator.
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that resets ``FP8GlobalStateManager`` after each test."""
    # Everything after the yield runs as teardown once the test has finished.
    yield
    FP8GlobalStateManager.reset()
164
165


166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
class TorchScaledMaskedSoftmax(nn.Module):
    """Reference softmax with optional scaling and masking, computed in FP32."""

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        """Apply ``softmax(mask(inp * scale))`` over the last dimension.

        The input is upcast to float32 for the computation and the result
        is cast back to the input dtype. Scaling is skipped when ``scale``
        is ``None``; masking (via ``attention_mask_func``) is skipped when
        ``mask`` is ``None``.
        """
        out_dtype = inp.dtype
        scores = inp.float()

        if scale is not None:
            scores = scores * scale
        if mask is not None:
            scores = attention_mask_func(scores, mask)

        softmax = torch.nn.Softmax(dim=-1)
        return softmax(scores).to(out_dtype)


class TorchDotProductAttention(torch.nn.Module):
    """Unfused PyTorch reference implementation of scaled dot-product attention.

    Scores are scaled by ``1 / sqrt(kv_channels)``, passed through
    ``TorchScaledMaskedSoftmax`` and dropout, and then applied to the values.
    """

    def __init__(
        self,
        kv_channels: int,
        attention_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.norm_factor = math.sqrt(kv_channels)
        self.scale_mask_softmax = TorchScaledMaskedSoftmax()
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention for inputs indexed as [s, b, np, hn].

        Args:
            query_layer: Queries, indexed as [sq, b, np, hn].
            key_layer: Keys, indexed as [sk, b, np, hn].
            value_layer: Values, indexed as [sk, b, np, hn].
            attention_mask: Optional mask forwarded to the softmax module.

        Returns:
            Context tensor viewed as [sq, b, np * hn].
        """
        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        # Allocate on the same device as the inputs instead of whatever the
        # current CUDA device happens to be.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 means the uninitialized buffer contents are ignored.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
        attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        # (query_layer was reshaped above, so size(0) is still sq.)
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        context_layer = context_layer.view(seqlen, batch_size, -1)

        return context_layer

273

274
class TorchLayerNorm(nn.Module):
    """PyTorch reference LayerNorm with optional zero-centered gamma.

    Args:
        in_features: Size of the normalized (last) dimension.
        eps: Epsilon passed to ``torch.nn.functional.layer_norm``.
        zero_centered_gamma: If True, the stored ``weight`` is an offset
            and ``1 + weight`` is used as the scale in ``forward``.
    """

    def __init__(self, in_features: int, eps: float, zero_centered_gamma: bool):
        super().__init__()
        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        # The effective scale must start at one in both modes: the stored
        # weight is zero when forward() adds one to it, and one otherwise.
        initial_value = torch.zeros(in_features) if zero_centered_gamma else torch.ones(in_features)
        # Assigning an nn.Parameter attribute registers it on the module.
        self.weight = nn.Parameter(initial_value)
        self.bias = nn.Parameter(torch.zeros(in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension in FP32; return in ``x.dtype``."""
        w = self.weight if not self.zero_centered_gamma else 1 + self.weight
        w = w.to(torch.float32)
        b = self.bias.to(torch.float32)
        inp = x.to(torch.float32)
        out = torch.nn.functional.layer_norm(
            inp, (self.in_features,), weight=w, bias=b, eps=self.eps
        )
        return out.to(x.dtype)

297

298
299
# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
    """PyTorch reference RMSNorm with optional zero-centered gamma.

    Args:
        in_features: Size of the normalized (last) dimension.
        zero_centered_gamma: If True, the stored ``weight`` is an offset
            and ``1 + weight`` is used as the scale in ``forward``.
        eps: Value added to the mean of squares before the inverse root.
    """

    def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
        super().__init__()

        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        # The effective scale must start at one in both modes: the stored
        # weight is zero when forward() adds one to it, and one otherwise.
        initial_value = torch.zeros(in_features) if zero_centered_gamma else torch.ones(in_features)
        # Assigning an nn.Parameter attribute registers it on the module.
        self.weight = nn.Parameter(initial_value)

    def forward(self, x):
        """Scale ``x`` by the inverse RMS of its last dimension; return in ``x.dtype``."""
        # Sum of squares is accumulated in FP32.
        norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
        d_x = self.in_features

        rms_x2 = norm_x2 / d_x + self.eps
        r_rms_x = rms_x2 ** (-1.0 / 2)
        x_normed = x * r_rms_x

        w = self.weight.float()
        if self.zero_centered_gamma:
            w = 1 + w
        return (w * x_normed).to(x.dtype)
323

324

325
class TorchLayerNormLinear(nn.Module):
    """PyTorch reference for a normalization layer followed by a linear layer.

    Args:
        in_features: Input feature size (also the normalized dimension).
        out_features: Output feature size of the linear layer.
        eps: Epsilon forwarded to the normalization layer.
        bias: Whether the linear layer has a bias term.
        normalization: ``"LayerNorm"`` or ``"RMSNorm"``.
        zero_centered_gamma: Forwarded to the normalization layer.

    Raises:
        RuntimeError: If ``normalization`` is not a supported name.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float,
        bias: bool = True,
        normalization: str = "LayerNorm",
        zero_centered_gamma: bool = False,
    ):
        super().__init__()
        if normalization == "LayerNorm":
            self.layernorm = TorchLayerNorm(
                in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
            )
        elif normalization == "RMSNorm":
            self.layernorm = TorchRMSNorm(
                in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
            )
        else:
            raise RuntimeError("Unsupported normalization")

        # Honor the ``bias`` argument; it was previously accepted but ignored,
        # so the linear layer always had a bias.
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the normalization, then the linear layer."""
        return self.linear(self.layernorm(x))


class TorchMHA(nn.Module):
    """Self-attention reference built on ``torch.nn.MultiheadAttention``."""

    def __init__(self, hidden_size: int, num_attention_heads: int):
        super().__init__()
        mha_options = dict(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            dropout=0.1,
            bias=True,
            batch_first=False,
        )
        self.mhsa = nn.MultiheadAttention(**mha_options)

    def forward(self, x, attention_mask=None):
        """Run self-attention with ``x`` as query, key and value.

        Attention weights are not requested; if the wrapped module returns
        a tuple, only its first element is returned.
        """
        result = self.mhsa(x, x, x, attn_mask=attention_mask, need_weights=False)
        return result[0] if isinstance(result, tuple) else result

370

371
372
373
class TorchQuickGELU(nn.Module):
    """Reference QuickGELU activation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(1.702 * input)
        return input * gate
374

375

376
377
378
379
class TorchSquaredRELU(nn.Module):
    """Reference squared-ReLU activation: ``x * x`` where ``x > 0``, else zero."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        positive = input > 0
        return positive * input * input

380

381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
class TorchGroupedLinearWithPadding(nn.Module):
    """``GroupedLinear`` wrapped with ``Fp8Padding`` / ``Fp8Unpadding``.

    The padding and unpadding steps are applied only when ``fp8`` is true;
    otherwise the input goes straight through the grouped linear layer.
    """

    def __init__(
        self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8
    ) -> None:
        super().__init__()

        linear_options = dict(
            bias=bias,
            params_dtype=params_dtype,
            parallel_mode=parallel_mode,
            device="cuda",
        )
        self.padding = Fp8Padding(num_gemms)
        self.linear_fn = GroupedLinear(num_gemms, in_features, out_features, **linear_options)
        self.unpadding = Fp8Unpadding(num_gemms)

        self.fp8 = fp8

    def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor:
        """Apply the grouped linear layer, padding/unpadding around it in FP8 mode."""
        if not self.fp8:
            return self.linear_fn(inp, m_splits)

        # Pad, run the grouped GEMM on the padded splits, then unpad using
        # the caller's original splits.
        padded_inp, padded_splits = self.padding(inp, m_splits)
        padded_out = self.linear_fn(padded_inp, padded_splits)
        return self.unpadding(padded_out, m_splits)


415
416
417
418
419
420
421
422
423
# Maps an activation name to the PyTorch module applied for it in the
# reference models. The "*glu" names map to the activation that TorchGLU
# applies to the first half of its input. Module instances are created once
# here and shared by every lookup.
_supported_act = {
    "geglu": nn.GELU(approximate="tanh"),
    "gelu": nn.GELU(approximate="tanh"),
    "reglu": nn.ReLU(),
    "relu": nn.ReLU(),
    "swiglu": nn.SiLU(),
    "qgelu": TorchQuickGELU(),
    "srelu": TorchSquaredRELU(),
}
424

425

426
427
428
429
430
431
432
class TorchGLU(nn.Module):
    """Reference gated linear unit.

    The last dimension of the input is split in two; the activation looked
    up in ``_supported_act`` is applied to the first part and the result is
    multiplied by the second part.
    """

    def __init__(self, activation: str):
        super().__init__()
        self.act = _supported_act[activation]

    def forward(self, x):
        half = x.size(-1) // 2
        gate, linear = x[..., :half], x[..., half:]
        return self.act(gate) * linear
437

438

439
class TorchLayerNormMLP(nn.Module):
    """PyTorch reference for normalization followed by a two-layer MLP.

    For activation names containing ``"glu"`` the first linear layer
    produces ``2 * ffn_hidden_size`` features and a ``TorchGLU`` is applied;
    otherwise it produces ``ffn_hidden_size`` features and the activation
    from ``_supported_act`` is applied directly.

    Raises:
        RuntimeError: If ``normalization`` is not ``"LayerNorm"`` or ``"RMSNorm"``.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        eps: float = 1e-5,
        activation="gelu",
        normalization: str = "LayerNorm",
    ):
        super().__init__()
        if normalization not in ("LayerNorm", "RMSNorm"):
            raise RuntimeError("Unsupported normalization")
        norm_cls = TorchLayerNorm if normalization == "LayerNorm" else TorchRMSNorm
        self.ln = norm_cls(hidden_size, eps=eps, zero_centered_gamma=False)

        # Note: the attribute is called ``gelu`` whatever the activation is.
        is_gated = "glu" in activation
        self.gelu = TorchGLU(activation) if is_gated else _supported_act[activation]
        fc1_output_features = 2 * ffn_hidden_size if is_gated else ffn_hidden_size

        self.fc1 = nn.Linear(hidden_size, fc1_output_features)
        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)

    def forward(self, x):
        """Apply norm -> fc1 -> activation -> fc2."""
        normed = self.ln(x)
        activated = self.gelu(self.fc1(normed))
        return self.fc2(activated)
468
469
470


class TorchGPT(nn.Module):
    """PyTorch reference transformer block: LayerNorm, self-attention and an MLP.

    With ``parallel_attention_mlp`` the attention and MLP outputs are summed
    and added to the input in a single residual step; otherwise they are
    applied as two successive residual steps.
    """

    def __init__(
        self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool
    ):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)
        self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
        self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
        self.parallel_attention_mlp = parallel_attention_mlp

    def _hidden_dropout(self, t: torch.Tensor) -> torch.Tensor:
        """Dropout with p=0.1, active only while the module is in training mode."""
        return nn.functional.dropout(t, p=0.1, training=self.training)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        attn_out = self.causal_attn(self.ln(x), attention_mask)

        if self.parallel_attention_mlp:
            # The MLP sees the block input; both branches share one dropout.
            mlp_out = self.ln_mlp(x)
            return x + self._hidden_dropout(attn_out + mlp_out)

        # Sequential form: attention residual first, then the MLP residual.
        x = x + self._hidden_dropout(attn_out)
        mlp_out = self.ln_mlp(x)
        return x + self._hidden_dropout(mlp_out)


497
498
499
def _test_e2e_selective_recompute(
    bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False
):
    """Run one forward/backward pass of a ``TransformerLayer``.

    ``recompute`` is passed to the layer as ``checkpoint_core_attention``.
    RNG and global FP8 state are reset first so that repeated calls start
    from the same state.

    Returns:
        A list holding the layer output, the gradient of the input, and the
        gradient of every parameter that requires grad.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    init_std = 0.023
    layer_kwargs = dict(
        layernorm_epsilon=config.eps,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
        fuse_qkv_params=True,
        device="cuda",
    )

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            **layer_kwargs,
        )

    input_shape = (config.seq_len, bs, config.hidden_size)
    hidden_states = torch.randn(input_shape, dtype=dtype, device="cuda", requires_grad=True)
    hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.seq_len)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        te_out = block(
            hidden_states,
            attention_mask=causal_mask,
            checkpoint_core_attention=recompute,
        )
    te_out.sum().backward()
    torch.cuda.synchronize()

    param_grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [te_out, hidden_states.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
    """Results with core-attention recompute must match results without it."""
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    config = model_configs[model]

    # Reference run first, then the run with recompute enabled.
    outputs, outputs_recompute = [
        _test_e2e_selective_recompute(
            bs, dtype, config, fp8, recipe, fp8_model_params, recompute=flag
        )
        for flag in (False, True)
    ]

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols["atol"] = 1e-4
    if fp8 or fp8_model_params:
        tols.update(dict(rtol=0.125, atol=0.0675))

    for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
        torch.testing.assert_close(test, ref, msg=f"Mismatch in tensor {i}", **tols)
586
587


588
def _test_e2e_full_recompute(
    bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False, use_reentrant=True
):
    """Run one forward/backward pass of a ``TransformerLayer``, optionally checkpointed.

    With ``recompute`` the layer is invoked through ``te_checkpoint``;
    otherwise it is called directly. The input requires grad only when
    ``use_reentrant`` is true, and its gradient is reported only then.

    Returns:
        ``(outputs, names)``: the layer output followed by gradients, and a
        parallel list of labels (``"output"``, ``"input"``, parameter names).
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    init_std = 0.023
    layer_kwargs = dict(
        layernorm_epsilon=config.eps,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
        fuse_qkv_params=True,
        device="cuda",
    )

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            **layer_kwargs,
        )

    input_shape = (config.seq_len, bs, config.hidden_size)
    hidden_states = torch.randn(
        input_shape, dtype=dtype, device="cuda", requires_grad=use_reentrant
    )
    if use_reentrant:
        hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.seq_len)

    # Keyword arguments shared by the direct and the checkpointed call.
    call_kwargs = dict(attention_mask=causal_mask, checkpoint_core_attention=False)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if recompute:
            te_out = te_checkpoint(
                block,
                hidden_states,
                **call_kwargs,
                distribute_saved_activations=False,
                tp_group=None,
                use_reentrant=use_reentrant,
            )
        else:
            te_out = block(hidden_states, **call_kwargs)
    te_out.sum().backward()
    torch.cuda.synchronize()

    # Collect (label, tensor) pairs, then split them into the two lists.
    labelled = [("output", te_out)]
    if use_reentrant:
        labelled.append(("input", hidden_states.grad))
    labelled.extend(
        (name, p.grad) for name, p in block.named_parameters() if p.requires_grad
    )

    outputs = [tensor for _, tensor in labelled]
    names = [label for label, _ in labelled]
    return outputs, names
658
659
660
661


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_reentrant", all_boolean)
def test_gpt_full_activation_recompute(
    dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
):
    """Results with full activation recompute must match results without it."""
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    config = model_configs[model]

    fusion_env = "NVTE_BIAS_GELU_NVFUSION"
    previous_fusion = os.environ.get(fusion_env)
    if not use_reentrant:
        # Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion
        os.environ[fusion_env] = "0"

    try:
        outputs, _ = _test_e2e_full_recompute(
            bs,
            dtype,
            config,
            fp8,
            recipe,
            fp8_model_params,
            recompute=False,
            use_reentrant=use_reentrant,
        )
        outputs_recompute, _ = _test_e2e_full_recompute(
            bs,
            dtype,
            config,
            fp8,
            recipe,
            fp8_model_params,
            recompute=True,
            use_reentrant=use_reentrant,
        )
    finally:
        if not use_reentrant:
            # Reset bias+GELU fusion flag to avoid contaminating other tests.
            # This runs even if a run above raised, and puts back any value
            # that was set before the test rather than deleting it.
            if previous_fusion is None:
                os.environ.pop(fusion_env, None)
            else:
                os.environ[fusion_env] = previous_fusion

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols["atol"] = 1e-3
    if fp8 or fp8_model_params:
        tols.update(dict(rtol=0.125, atol=0.0675))
    for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )
719
720
721
722
723
724


def _test_e2e_checkpointing_get_model(config, dtype):
    """Construct the ``TransformerLayer`` used by the checkpointing test."""
    init_std = 0.023
    layer_kwargs = dict(
        layernorm_epsilon=config.eps,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
        device="cuda",
    )
    model = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        **layer_kwargs,
    )
    return model


def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path="checkpoint.pt"):
    """Run ``steps`` forward/backward passes, optionally reloading the model halfway.

    Gradients accumulate across passes. With ``checkpoint`` the model's
    state dict is saved to ``path`` after ``steps // 2`` passes, a fresh
    model is built and loaded from it, and the accumulated gradients and
    RNG state are restored before the remaining passes.

    Args:
        bs: Batch size of the random input.
        dtype: Parameter and input dtype.
        config: Model configuration (``seq_len``, ``hidden_size``, ...).
        checkpoint: Whether to save/reload the model halfway through.
        steps: Total number of passes; ``steps // 2`` run in each half.
        path: File used for the temporary checkpoint.

    Returns:
        A list holding the last output, the gradient of the input, and the
        gradient of every parameter that requires grad.
    """
    reset_rng_states()

    te_inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()

    block = _test_e2e_checkpointing_get_model(config, dtype)

    for _ in range(steps // 2):
        te_out = block(
            te_inp_hidden_states,
            None,
        )
        loss = te_out.sum()
        loss.backward()

    if checkpoint:
        # This process is necessary so that we can start afresh with
        # a new model while erasing all internal state to ensure that
        # loading from a checkpoint gives bitwise identical results.
        # Since gradients are being accumulated, it is important to
        # restore them post loading the checkpoint.
        param_grads = [p.grad.clone() for p in block.parameters() if p.requires_grad]

        # Snapshot the RNG state in locals. Writing it into the module-level
        # snapshots would change what reset_rng_states() restores for every
        # test that runs afterwards.
        cpu_rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state()

        try:
            torch.save(block.state_dict(), path)
            del block
            block = _test_e2e_checkpointing_get_model(config, dtype)
            block.load_state_dict(torch.load(path, weights_only=False))
        finally:
            # Remove the temporary checkpoint even if saving/loading failed.
            if os.path.exists(path):
                os.remove(path)

        # Building the new model consumed random numbers; rewind to the
        # state captured right after the first half.
        torch.set_rng_state(cpu_rng_state)
        torch.cuda.set_rng_state(cuda_rng_state)

        trainable = [p for p in block.parameters() if p.requires_grad]
        assert len(trainable) == len(param_grads), "Oops!"
        for p, grad in zip(trainable, param_grads):
            p.grad = grad

    for _ in range(steps // 2):
        te_out = block(
            te_inp_hidden_states,
            None,
        )
        loss = te_out.sum()
        loss.backward()

    torch.cuda.synchronize()

    outputs = [te_out, te_inp_hidden_states.grad]
    outputs.extend(p.grad for p in block.parameters() if p.requires_grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
    """Results with a mid-run save/reload must match an uninterrupted run."""
    config = model_configs[model]

    # Uninterrupted run first, then the run that saves and reloads.
    outputs, outputs_checkpoint = [
        _test_e2e_checkpointing(bs, dtype, config, checkpoint=flag) for flag in (False, True)
    ]

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(dict(rtol=2e-2, atol=2e-3))
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
        torch.testing.assert_close(test, ref, msg=f"Mismatch in tensor {i}", **tols)
831
832
833
834
835
836


def _test_e2e_gpt_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of ``block`` on a random input with a causal mask.

    Returns:
        A list holding the output, the gradient of the input, and the
        gradient of every parameter of ``block`` that requires grad.
    """
    reset_rng_states()

    input_shape = (config.seq_len, bs, config.hidden_size)
    hidden_states = torch.randn(input_shape, dtype=dtype, device="cuda", requires_grad=True)
    hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.seq_len)

    out = block(hidden_states, attention_mask=causal_mask)
    out.sum().backward()

    torch.cuda.synchronize()
    param_grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [out, hidden_states.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
    """Compare a TE ``TransformerLayer`` against the ``TorchGPT`` reference.

    Both models run in eval mode with parameters copied from the TE layer.
    """
    config = model_configs[model]

    te_gpt = TransformerLayer(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        num_attention_heads=config.num_attention_heads,
        layernorm_epsilon=config.eps,
        attention_dropout=0.1,
        hidden_dropout=0.1,
        params_dtype=dtype,
        fuse_qkv_params=True,
        qkv_weight_interleaved=False,
        parallel_attention_mlp=parallel_attention_mlp,
        device="cuda",
    ).eval()

    torch_gpt = TorchGPT(
        config.hidden_size,
        config.eps,
        config.num_attention_heads,
        parallel_attention_mlp=parallel_attention_mlp,
    )
    torch_gpt = torch_gpt.to(dtype=dtype).cuda().eval()

    # Share params
    with torch.no_grad():
        te_qkv = te_gpt.self_attention.layernorm_qkv
        te_proj = te_gpt.self_attention.proj
        te_mlp = te_gpt.layernorm_mlp
        torch_attn = torch_gpt.causal_attn.mhsa
        torch_mlp = torch_gpt.ln_mlp

        # (destination module, attribute name, TE source tensor)
        shared_params = (
            (torch_gpt.ln, "weight", te_qkv.layer_norm_weight),
            (torch_gpt.ln, "bias", te_qkv.layer_norm_bias),
            (torch_attn, "in_proj_weight", te_qkv.weight),
            (torch_attn, "in_proj_bias", te_qkv.bias),
            (torch_attn.out_proj, "weight", te_proj.weight),
            (torch_attn.out_proj, "bias", te_proj.bias),
            (torch_mlp.ln, "weight", te_mlp.layer_norm_weight),
            (torch_mlp.ln, "bias", te_mlp.layer_norm_bias),
            (torch_mlp.fc1, "weight", te_mlp.fc1_weight),
            (torch_mlp.fc1, "bias", te_mlp.fc1_bias),
            (torch_mlp.fc2, "weight", te_mlp.fc2_weight),
            (torch_mlp.fc2, "bias", te_mlp.fc2_bias),
        )
        for module, attr, source in shared_params:
            setattr(module, attr, Parameter(source.clone()))

    te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config)
    torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config)

    output_atol = {
        torch.float32: 5e-3,
        torch.half: 5e-2,
        torch.bfloat16: 1e-1,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], output_atol[dtype])

    # Check gradients, only for small model
    if model == "small":
        # Same absolute tolerances as above, except a looser FP32 value.
        grad_atol = {**output_atol, torch.float32: 5e-2}
        grad_rtol = {
            torch.float32: 1e-2,
            torch.half: 1e-2,
            torch.bfloat16: 1e-2,
        }
        grad_pairs = zip(te_outputs[1:], torch_outputs[1:])
        for te_grad, torch_grad in grad_pairs:
            assert_allclose(te_grad, torch_grad, grad_atol[dtype], grad_rtol[dtype])
937
938


939
def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
    """Run one forward/backward pass of an attention block on random input.

    The RNG state is reset first so that repeated calls draw the same input.
    When ``te`` is true the block additionally receives ``attn_mask_type``;
    the ``attention_mask`` keyword is always passed (``None`` unless
    ``mask_type`` is ``"causal"``).

    Returns a list: the block output, the gradient w.r.t. the input, then the
    gradients of every trainable parameter in ``block.parameters()`` order.
    """
    reset_rng_states()

    hidden = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    hidden.retain_grad()

    if mask_type == "causal":
        attn_mask = get_causal_attn_mask(config.seq_len)
    else:
        attn_mask = None

    # Only the TE module is given `attn_mask_type`.
    call_kwargs = {"attn_mask_type": mask_type} if te else {}
    call_kwargs["attention_mask"] = attn_mask

    out = block(hidden, **call_kwargs)
    out.sum().backward()

    torch.cuda.synchronize()
    param_grads = [param.grad for param in block.parameters() if param.requires_grad]
    return [out, hidden.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
    """Check TE ``MultiheadAttention`` against the ``TorchMHA`` reference.

    Both modules are given the same QKV/projection parameters; the forward
    output is always compared, gradients only for the "small" model.
    """
    config = model_configs[model]

    te_mha = MultiheadAttention(
        config.hidden_size,
        config.num_attention_heads,
        fuse_qkv_params=True,
        params_dtype=dtype,
        qkv_weight_interleaved=False,
        input_layernorm=False,
        device="cuda",
    ).eval()

    torch_mha = TorchMHA(config.hidden_size, config.num_attention_heads)
    torch_mha = torch_mha.to(dtype=dtype).cuda().eval()

    # Copy the TE parameters into the reference module.
    with torch.no_grad():
        ref_attn = torch_mha.mhsa
        ref_attn.in_proj_weight = Parameter(te_mha.qkv.weight.clone())
        ref_attn.in_proj_bias = Parameter(te_mha.qkv.bias.clone())
        ref_attn.out_proj.weight = Parameter(te_mha.proj.weight.clone())
        ref_attn.out_proj.bias = Parameter(te_mha.proj.bias.clone())

    te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type, te=True)
    torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type, te=False)

    # Forward output: tighter tolerance for FP32 than for the 16-bit dtypes.
    output_tol = 5e-3 if dtype == torch.float32 else 5e-2
    assert_allclose(te_outputs[0], torch_outputs[0], output_tol)

    # Gradients are only compared for the small model.
    if model != "small":
        return
    grad_atol = {
        torch.float32: 5e-2,
        torch.half: 5e-2,
        torch.bfloat16: 5e-2,
    }[dtype]
    grad_rtol = {
        torch.float32: 1e-2,
        torch.half: 1e-2,
        torch.bfloat16: 1e-2,
    }[dtype]
    for te_grad, torch_grad in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_grad, torch_grad, grad_atol, grad_rtol)

1026

1027
1028
1029
1030
def _test_granular_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of ``block`` on a random input.

    The RNG state is reset first so that repeated calls draw the same input of
    shape ``(config.seq_len, bs, config.hidden_size)``.

    Returns a list: the block output, the gradient w.r.t. the input, then the
    gradients of every trainable parameter in ``block.parameters()`` order.
    """
    reset_rng_states()

    input_shape = (config.seq_len, bs, config.hidden_size)
    hidden = torch.randn(input_shape, dtype=dtype, device="cuda", requires_grad=True)
    hidden.retain_grad()

    out = block(hidden)
    out.sum().backward()

    torch.cuda.synchronize()
    results = [out, hidden.grad]
    results.extend(param.grad for param in block.parameters() if param.requires_grad)
    return results


1050
1051
1052
def _test_dpa_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of an attention core on random Q/K/V.

    The RNG state is reset first.  The boolean mask is true strictly above the
    main diagonal (``torch.triu(..., diagonal=1)``).

    Returns ``[out, query.grad, key.grad, value.grad]``.
    """
    reset_rng_states()

    seq_len = config.seq_len
    all_true = torch.ones(seq_len, seq_len, dtype=torch.bool, device="cuda")
    mask = torch.triu(all_true, diagonal=1)

    # Query, key and value are drawn in this order from the same RNG stream.
    qkv_shape = (seq_len, bs, config.num_attention_heads, config.embed)
    qkv = []
    for _ in range(3):
        qkv.append(torch.randn(qkv_shape, dtype=dtype, device="cuda", requires_grad=True))
    query, key, value = qkv

    for tensor in (query, key, value):
        tensor.retain_grad()

    out = block(query, key, value, attention_mask=mask)
    out.sum().backward()

    torch.cuda.synchronize()

    return [out, query.grad, key.grad, value.grad]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_dpa_accuracy(dtype, bs, model):
    """Check TE ``DotProductAttention`` against ``TorchDotProductAttention``."""
    config = model_configs[model]

    te_dpa = DotProductAttention(
        config.num_attention_heads,
        config.embed,
        attention_dropout=0.0,  # disable dropout, FU uses rng differently
    )
    te_dpa = te_dpa.to(dtype=dtype).cuda()

    torch_dpa = TorchDotProductAttention(
        config.embed,
        0.0,  # dropout
    )
    torch_dpa = torch_dpa.to(dtype=dtype).cuda()

    te_outputs = _test_dpa_accuracy(te_dpa, bs, dtype, config)
    torch_outputs = _test_dpa_accuracy(torch_dpa, bs, dtype, config)

    # Forward output: tighter tolerance for FP32 than for the 16-bit dtypes.
    output_tol = 5e-3 if dtype == torch.float32 else 5e-2
    assert_allclose(te_outputs[0], torch_outputs[0], output_tol)

    # Gradients w.r.t. query, key and value.
    for te_grad, torch_grad in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_grad, torch_grad, atol=5e-2, rtol=1e-2)

1116

1117
1118
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
def test_linear_accuracy(dtype, bs, model):
    """Check TE ``Linear`` against ``torch.nn.Linear`` with shared parameters."""
    config = model_configs[model]
    in_features = config.hidden_size
    out_features = 4 * config.hidden_size

    te_linear = Linear(
        in_features,
        out_features,
        bias=True,
        params_dtype=dtype,
        device="cuda",
    ).eval()

    torch_linear = torch.nn.Linear(
        in_features,
        out_features,
        bias=True,
        device="cuda",
        dtype=dtype,
    ).eval()

    # Copy the TE parameters into the reference module.
    with torch.no_grad():
        torch_linear.weight = Parameter(te_linear.weight.clone())
        torch_linear.bias = Parameter(te_linear.bias.clone())

    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)

    # Output and gradients are compared together, only for the small model.
    if model != "small":
        return
    abs_tol = 5e-3 if dtype == torch.float32 else 5e-2
    rel_tol = {
        torch.float32: 1.3e-6,
        torch.half: 1e-2,
        torch.bfloat16: 2e-2,
    }[dtype]
    for te_result, torch_result in zip(te_outputs, torch_outputs):
        assert_allclose(te_result, torch_result, abs_tol, rel_tol)
1157

1158

1159
1160
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
    """Check TE ``RMSNorm`` against the ``TorchRMSNorm`` reference."""
    config = model_configs[model]

    te_rmsnorm = RMSNorm(
        config.hidden_size,
        eps=eps,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_rmsnorm = TorchRMSNorm(
        config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma
    )
    torch_rmsnorm = torch_rmsnorm.to(dtype=dtype).cuda().eval()

    # Copy the TE weight into the reference module.
    with torch.no_grad():
        torch_rmsnorm.weight = Parameter(te_rmsnorm.weight.clone())

    te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)

    # Forward output.
    output_atol = {
        torch.float32: 1e-7,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }
    assert_allclose(te_outputs[0], torch_outputs[0], output_atol[dtype])

    # Gradients: same absolute tolerances, except FP32 which is relaxed to 2e-3.
    grad_atol = {**output_atol, torch.float32: 2e-3}
    grad_rtol = {
        torch.float32: 1.3e-6,
        torch.half: 1e-3,
        torch.bfloat16: 1.6e-2,
    }
    for te_grad, torch_grad in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_grad, torch_grad, grad_atol[dtype], grad_rtol[dtype])

1208

1209
1210
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
    """Check TE ``LayerNorm`` against the ``TorchLayerNorm`` reference."""
    config = model_configs[model]

    te_layernorm = LayerNorm(
        config.hidden_size,
        eps=eps,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_layernorm = TorchLayerNorm(
        config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma
    )
    torch_layernorm = torch_layernorm.to(dtype=dtype).cuda().eval()

    # Copy the TE parameters into the reference module.
    with torch.no_grad():
        torch_layernorm.weight = Parameter(te_layernorm.weight.clone())
        torch_layernorm.bias = Parameter(te_layernorm.bias.clone())

    te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)

    # Forward output.
    output_atol = {
        torch.float32: 1e-7,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }
    assert_allclose(te_outputs[0], torch_outputs[0], output_atol[dtype])

    # Gradients: same absolute tolerances, except FP32 which is relaxed to 1e-4.
    grad_atol = {**output_atol, torch.float32: 1e-4}
    grad_rtol = {
        torch.float32: 1.3e-6,
        torch.half: 1e-3,
        torch.bfloat16: 1.6e-2,
    }
    for te_grad, torch_grad in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_grad, torch_grad, grad_atol[dtype], grad_rtol[dtype])

1259

1260
1261
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
    """Check TE ``LayerNormLinear`` against the ``TorchLayerNormLinear`` reference."""
    config = model_configs[model]
    in_features = config.hidden_size
    out_features = 4 * config.hidden_size

    te_ln_linear = LayerNormLinear(
        in_features,
        out_features,
        config.eps,
        bias=True,
        normalization=normalization,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_ln_linear = TorchLayerNormLinear(
        in_features,
        out_features,
        config.eps,
        bias=True,
        normalization=normalization,
        zero_centered_gamma=zero_centered_gamma,
    )
    torch_ln_linear = torch_ln_linear.to(dtype=dtype).cuda().eval()

    # Copy the TE parameters into the reference module.
    with torch.no_grad():
        ref_norm = torch_ln_linear.layernorm
        ref_norm.weight = Parameter(te_ln_linear.layer_norm_weight.clone())
        # The norm bias is not copied for RMSNorm.
        if normalization != "RMSNorm":
            ref_norm.bias = Parameter(te_ln_linear.layer_norm_bias.clone())
        torch_ln_linear.linear.weight = Parameter(te_ln_linear.weight.clone())
        torch_ln_linear.linear.bias = Parameter(te_ln_linear.bias.clone())

    te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)

    # Forward output.
    output_atol = {
        torch.float32: 2.5e-4,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }
    output_rtol = {
        torch.float32: 1e-3,
        torch.half: 4e-2,
        torch.bfloat16: 4e-2,
    }
    assert_allclose(te_outputs[0], torch_outputs[0], output_atol[dtype], output_rtol[dtype])

    # Gradients are only compared for the small model.
    if model != "small":
        return
    grad_atol = {
        torch.float32: 1e-3,
        torch.half: 5e-2,
        torch.bfloat16: 5e-2,
    }
    grad_rtol = {
        torch.float32: 1e-3,
        torch.half: 4e-2,
        torch.bfloat16: 4e-2,
    }
    for te_grad, torch_grad in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_grad, torch_grad, grad_atol[dtype], grad_rtol[dtype])

1333

1334
1335
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
    """Check TE ``LayerNormMLP`` against the ``TorchLayerNormMLP`` reference."""
    config = model_configs[model]
    hidden_size = config.hidden_size
    ffn_hidden_size = 4 * config.hidden_size

    te_ln_mlp = LayerNormMLP(
        hidden_size,
        ffn_hidden_size,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    ).eval()

    torch_ln_mlp = TorchLayerNormMLP(
        hidden_size,
        ffn_hidden_size,
        activation=activation,
        normalization=normalization,
    )
    torch_ln_mlp = torch_ln_mlp.to(dtype=dtype).cuda().eval()

    # Copy the TE parameters into the reference module.
    with torch.no_grad():
        torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.layer_norm_weight.clone())
        # The norm bias is not copied for RMSNorm.
        if normalization != "RMSNorm":
            torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone())
        torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.fc1_weight.clone())
        torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.fc1_bias.clone())
        torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.fc2_weight.clone())
        torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.fc2_bias.clone())

    te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)

    # Forward output.
    output_atol = {
        torch.float32: 2e-2,
        torch.half: 5e-2,
        torch.bfloat16: 5e-2,
    }
    output_rtol = {
        torch.float32: 1e-3,
        torch.half: 4e-2,
        torch.bfloat16: 4e-2,
    }
    assert_allclose(te_outputs[0], torch_outputs[0], output_atol[dtype], output_rtol[dtype])

    # Gradients are only compared for the small model; the 16-bit absolute
    # tolerances are relaxed to 2e-1 while FP32 keeps the output tolerance.
    grad_atol = {**output_atol, torch.half: 2e-1, torch.bfloat16: 2e-1}
    grad_rtol = {
        torch.float32: 1e-3,
        torch.half: 1e-2,
        torch.bfloat16: 4e-2,
    }
    if model != "small":
        return
    for te_grad, torch_grad in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_grad, torch_grad, grad_atol[dtype], grad_rtol[dtype])
1402
1403


1404
1405
1406
def _test_grouped_linear_accuracy(
    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
):
    """Run one forward/backward pass through ``block`` with a random row split.

    ``block`` is either a ``GroupedLinear`` (called once with the list of split
    sizes) or an indexable container of modules, where ``block[i]`` is applied
    to the i-th chunk of the input and the results are concatenated.

    The RNG state (and, when ``fp8`` is set, the global FP8 state) is reset
    first so that both kinds of block see the same input and the same splits.

    ``fuse_wgrad_accumulation`` is not read here; whether ``main_grad`` or
    ``grad`` is collected is decided per parameter from the presence of a
    ``main_grad`` attribute.

    Returns a list: the output, the gradient w.r.t. the input, then one
    gradient per trainable parameter (``main_grad`` when set, else ``grad``).
    """
    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    if num_gemms > 1:
        # Split sizes are multiples of `split_size`: 1 without FP8, 16 for the
        # delayed-scaling recipe and 128 for MXFP8.
        split_size = 1
        if fp8:
            if recipe.delayed():
                split_size = 16
            if recipe.mxfp8():
                split_size = 128
        m = config.seq_len // split_size
        # `num_gemms - 2` random break points in [0, m), sorted.
        dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
        # Repeat the last break point so that one split is guaranteed to be
        # empty.  With num_gemms == 2 there are no break points to repeat, so
        # fall back to 0 instead of indexing an empty list.
        dist.append(dist[-1] if dist else 0)
        m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
        m_splits = m_splits * split_size
        assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
    else:
        m_splits = torch.tensor([config.seq_len])

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, GroupedLinear):
            # NOTE(review): the splits are scaled by `bs` here, presumably
            # because GroupedLinear splits over the flattened
            # (seq_len * bs) rows -- confirm against GroupedLinear.forward.
            m_splits = m_splits * bs
            out = block(inp_hidden_states, m_splits.tolist())
        else:
            out = torch.cat(
                [
                    block[i](inp)
                    for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                ]
            )
    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            if getattr(p, "main_grad", None) is not None:
                outputs.append(p.main_grad)
                assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
            else:
                outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_grouped_linear_accuracy(
    dtype,
    num_gemms,
    bs,
    model,
    fp8,
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
    parallel_mode=None,
):
    """Check ``GroupedLinear`` against ``num_gemms`` separate ``Linear`` layers.

    Both sides share weights, biases and (when ``fuse_wgrad_accumulation`` is
    set) initial ``main_grad`` buffers; outputs and gradients must match
    exactly.
    """
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
        pytest.skip("MXFP8 unsupported for grouped linear.")
    if fp8 and recipe.float8_current_scaling():
        pytest.skip("Float8 Current Scaling unsupported for grouped linear.")

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    in_features = config.hidden_size
    out_features = 4 * config.hidden_size
    # Keyword arguments common to the grouped layer and the reference layers.
    layer_kwargs = dict(
        bias=True,
        params_dtype=dtype,
        parallel_mode=parallel_mode,
        device="cuda",
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    )

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = GroupedLinear(
            num_gemms, in_features, out_features, **layer_kwargs
        ).eval()
        reference_layers = []
        for _ in range(num_gemms):
            reference_layers.append(Linear(in_features, out_features, **layer_kwargs).eval())
        sequential_linear = torch.nn.ModuleList(reference_layers)

    # Copy the grouped parameters into the reference layers.
    with torch.no_grad():
        for idx, ref_layer in enumerate(sequential_linear):
            grouped_weight = getattr(grouped_linear, f"weight{idx}")
            ref_layer.weight = Parameter(grouped_weight.clone())
            ref_layer.bias = Parameter(getattr(grouped_linear, f"bias{idx}").clone())
            if fuse_wgrad_accumulation:
                # Start both sides from the same random FP32 main_grad buffer.
                grouped_weight.main_grad = torch.rand_like(grouped_weight, dtype=torch.float32)
                ref_layer.weight.main_grad = grouped_weight.main_grad.clone()

    run_args = (num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation)
    outputs_ref = _test_grouped_linear_accuracy(sequential_linear, *run_args)
    outputs = _test_grouped_linear_accuracy(grouped_linear, *run_args)

    # Should be a bit-wise match.
    for result, expected in zip(outputs, outputs_ref):
        torch.testing.assert_close(result, expected, rtol=0, atol=0)


1541
@pytest.mark.parametrize("parallel_mode", ["column", "row"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe):
    """Run the grouped-linear check once per parallel mode.

    Kept separate from the main parametrization to save CI time.
    """
    fixed_config = dict(
        dtype=torch.float32,
        num_gemms=6,
        bs=2,
        model="126m",
        fp8=True,
        fp8_model_params=True,
        fuse_wgrad_accumulation=True,
    )
    test_grouped_linear_accuracy(recipe=recipe, parallel_mode=parallel_mode, **fixed_config)


1558
1559
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_grouped_linear_accuracy_single_gemm(recipe):
    """Run the grouped-linear check with a single GEMM.

    Kept separate from the main parametrization to save CI time.
    """
    fixed_config = dict(
        dtype=torch.float32,
        num_gemms=1,
        bs=2,
        model="126m",
        fp8=True,
        fp8_model_params=True,
        fuse_wgrad_accumulation=True,
    )
    test_grouped_linear_accuracy(recipe=recipe, **fixed_config)


1573
def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
    """Run one forward/backward pass through a grouped linear with uneven splits.

    A ``TorchGroupedLinearWithPadding`` block is called directly with the raw
    split sizes.  Any other block is, under FP8, fed an input whose per-split
    row counts have been zero-padded up to multiples of 16, and the padding
    rows are stripped from its output again; without FP8 it is called directly.

    The RNG state (and, when ``fp8`` is set, the global FP8 state) is reset
    first so that both kinds of block see the same input and the same splits.

    Returns a list: the output, the gradient w.r.t. the input, then the
    gradients of every trainable parameter in ``block.parameters()`` order.
    """

    def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
        """Zero-pad each split of ``hidden_states`` to a multiple of 16 rows.

        Returns the padded tensor and the list of padded split sizes.
        """
        padded_tokens_per_expert = [
            (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert
        ]
        hidden_states = torch.split(hidden_states, tokens_per_expert)
        padded_hidden_states = []
        for hidden_state, actual_num_tokens, padded_num_tokens in zip(
            hidden_states, tokens_per_expert, padded_tokens_per_expert
        ):
            padded_hidden_states.append(hidden_state)
            if padded_num_tokens > actual_num_tokens:
                pad_tensor = torch.zeros(
                    padded_num_tokens - actual_num_tokens,
                    hidden_state.shape[1],
                    dtype=hidden_state.dtype,
                    device=hidden_state.device,
                )
                padded_hidden_states.append(pad_tensor)
        padded_hidden_states = torch.cat(padded_hidden_states, dim=0)
        return padded_hidden_states, padded_tokens_per_expert

    def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert):
        """Drop the padding rows: keep the first ``actual`` rows of each padded split."""
        padded_chunks = torch.split(
            padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert
        )
        hidden_states = torch.cat(
            [
                padded_chunk[: actual_tokens_per_expert[i]]
                for i, padded_chunk in enumerate(padded_chunks)
            ],
            dim=0,
        )

        return hidden_states

    def _generate_random_numbers(n, total_sum):
        """Return ``n`` positive integers that add up to ``total_sum``.

        The module-level RNG is re-seeded with ``seed`` on every call, so the
        same ``(n, total_sum)`` always yields the same split.
        """
        if n <= 0:
            return []

        # reset seed
        random.seed(seed)

        # n - 1 distinct interior break points; consecutive differences of
        # [0, *breaks, total_sum] are the split sizes.  This form also covers
        # n == 1 (no break points), which previously failed on `breaks[0]`.
        breaks = sorted(random.sample(range(1, total_sum), n - 1))
        bounds = [0] + breaks + [total_sum]
        return [upper - lower for lower, upper in zip(bounds, bounds[1:])]

    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.seq_len * bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, TorchGroupedLinearWithPadding):
            out = block(inp_hidden_states, m_splits)
        else:
            if fp8:
                padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
                    inp_hidden_states, m_splits
                )
                padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits)
                out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits)
            else:
                out = block(inp_hidden_states, m_splits)

    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy(
    dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None
):
    """Check ``TorchGroupedLinearWithPadding`` against a plain ``GroupedLinear``.

    The reference ``GroupedLinear`` receives copies of the weights held by the
    padded module's ``linear_fn``; outputs and gradients must match exactly.
    """
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
        pytest.skip("MXFP8 unsupported for grouped linear.")
    if fp8 and recipe.float8_current_scaling():
        pytest.skip("Float8 Current Scaling unsupported for grouped linear.")

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    in_features = config.hidden_size
    out_features = 4 * config.hidden_size
    use_fp8_params = fp8 and fp8_model_params

    with fp8_model_init(enabled=use_fp8_params, recipe=recipe):
        grouped_linear = TorchGroupedLinearWithPadding(
            num_gemms,
            in_features,
            out_features,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            fp8=fp8,
        ).eval()

    with fp8_model_init(enabled=use_fp8_params, recipe=recipe):
        ref_grouped_linear = GroupedLinear(
            num_gemms,
            in_features,
            out_features,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
        ).eval()

    # Copy the weights of the padded module's inner layer into the reference.
    with torch.no_grad():
        inner_grouped_linear = grouped_linear.linear_fn
        for idx in range(num_gemms):
            weight_name = f"weight{idx}"
            shared_weight = Parameter(getattr(inner_grouped_linear, weight_name).clone())
            setattr(ref_grouped_linear, weight_name, shared_weight)

    run_args = (num_gemms, bs, dtype, config, recipe, fp8)
    outputs = _test_padding_grouped_linear_accuracy(grouped_linear, *run_args)
    outputs_ref = _test_padding_grouped_linear_accuracy(ref_grouped_linear, *run_args)

    # Should be a bit-wise match.
    for result, expected in zip(outputs, outputs_ref):
        torch.testing.assert_close(result, expected, rtol=0, atol=0)


1732
1733
1734
1735
1736
1737
1738
def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
    """Run SGD training steps on ``block``, optionally replaying a CUDA graph.

    Three warmup steps run on a side stream.  When ``graph`` is true, one more
    step is captured into a ``torch.cuda.CUDAGraph`` and then replayed after
    the static input/target buffers have been overwritten with new data;
    otherwise the step after the overwrite is run eagerly.

    Returns ``(output, grads)``: a clone of the last step's output, and a list
    holding the input gradient followed by the gradients of every trainable
    parameter.
    """
    reset_rng_states()

    # Loss function and optimizer.
    criterion = torch.nn.MSELoss()
    sgd = torch.optim.SGD(block.parameters(), lr=0.1)

    # Static buffers: the same tensors are used during capture and replay.
    io_shape = (config.seq_len, bs, config.hidden_size)
    static_input = torch.randn(*io_shape, device="cuda", dtype=dtype, requires_grad=True)
    static_target = torch.randn(*io_shape, device="cuda", dtype=dtype)

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    def train_step():
        """One optimizer step on the static buffers; returns the forward output."""
        sgd.zero_grad(set_to_none=False)
        prediction = block(static_input)
        loss = criterion(prediction, static_target)
        loss.backward()
        sgd.step()
        return prediction

    # Warm up on a side stream.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _ in range(3):
            train_step()
    torch.cuda.current_stream().wait_stream(side_stream)

    # Capture a single training step if requested.
    cuda_graph = None
    static_output = None
    if graph:
        cuda_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(cuda_graph):
            static_output = train_step()

    # Load the new data into the static buffers and run one more step.
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    if graph:
        cuda_graph.replay()
    else:
        static_output = train_step()

    grads = [static_input.grad]
    grads.extend(param.grad for param in block.parameters() if param.requires_grad)

    with torch.no_grad():
        output = static_output.clone()
    return output, grads


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_cuda_graph(dtype, bs, model):
    """Check that a CUDA-graphed ``TransformerLayer`` trains like an eager one.

    Two layers are built with the same arguments, the second receives a copy
    of the first one's parameters, and output, parameters and gradients are
    compared after training.
    """
    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_args = (
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
    )
    layer_kwargs = {
        "layernorm_epsilon": config.eps,
        "init_method": init_method,
        "output_layer_init_method": output_layer_init_method,
        "hidden_dropout": 0.1,
        "attention_dropout": 0.1,
        "kv_channels": config.embed,
        "params_dtype": dtype,
        "apply_residual_connection_post_layernorm": False,
        "output_layernorm": False,
        "device": "cuda",
    }
    block = TransformerLayer(*layer_args, **layer_kwargs)
    graphed_block = TransformerLayer(*layer_args, **layer_kwargs)

    # Make the graphed layer start from the eager layer's parameters.
    with torch.no_grad():
        for eager_param, graphed_param in zip(block.parameters(), graphed_block.parameters()):
            graphed_param.copy_(eager_param)

    out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
    graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
    params = list(block.parameters())
    graphed_params = list(graphed_block.parameters())

    # Eager and graphed runs must agree.
    assert_allclose(out, graphed_out, 1e-3)
    assert_allclose(params, graphed_params, 1e-3)
    assert_allclose(grads, graphed_grads, 1e-3)
1834
1835


1836
def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
    """Run one forward/backward pass of a TransformerLayer under FP8 autocast.

    The layer is constructed inside `fp8_model_init(enabled=fp8_model_params,
    recipe=recipe)` and always executed inside
    `fp8_autocast(enabled=True, fp8_recipe=recipe)`.

    Returns:
        A list holding the layer output, the gradient of the input tensor and
        then `p.grad` for every parameter with `requires_grad` set.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    layer_kwargs = {
        "layernorm_epsilon": config.eps,
        "init_method": init_method_normal(sigma),
        "output_layer_init_method": scaled_init_method_normal(sigma, config.num_layers),
        "hidden_dropout": 0.1,
        "attention_dropout": 0.1,
        "kv_channels": config.embed,
        "apply_residual_connection_post_layernorm": False,
        "output_layernorm": False,
        "params_dtype": dtype,
        "fuse_qkv_params": True,
        "device": "cuda",
    }

    with fp8_model_init(enabled=fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            **layer_kwargs,
        )

    hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.seq_len)

    with fp8_autocast(enabled=True, fp8_recipe=recipe):
        te_out = block(hidden_states, attention_mask=causal_mask)
    te_out.sum().backward()
    torch.cuda.synchronize()

    param_grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [te_out, hidden_states.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
    """Check that FP8 model parameters give results close to regular parameters.

    `_test_gpt_fp8_parameters` is run twice, with `fp8_model_params` set to
    `False` and then `True`, and the returned tensors are compared pairwise.
    The test is skipped when FP8 is unavailable, or when the recipe is MXFP8
    and MXFP8 is unavailable.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    config = model_configs[model]

    outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe)
    outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True, recipe)

    # Check that results match
    # The tolerances are defined once and passed through, so the values used
    # by the comparison cannot drift away from the ones declared here.
    tols = dict(rtol=0.125, atol=0.0675)
    for i, (ref, test) in enumerate(zip(outputs, outputs_fp8_params)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )

1910
1911
1912

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_transformer_layer_hidden_states_format(dtype, bs, model):
    """Check that TransformerLayer output does not depend on the input layout.

    Three layers are built with `attn_input_format` set to "sbhd", "bshd" and
    "thd" after seeding the RNG identically, and their parameters are asserted
    to be equal. The same random data is then fed to each layer in the
    matching layout and the outputs are compared. The THD comparison only
    runs for non-float32 dtypes when `sm_80plus` is true.
    """
    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    # Arguments common to all three layers. `*dropout` values are 0 to make
    # sure the forward pass is identical across layers.
    shared_kwargs = dict(
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0,
        attention_dropout=0,
        kv_channels=config.embed,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        device="cuda",
    )

    def _build_layer(**format_kwargs):
        # Set `torch.manual_seed` to make sure the weights are identical to
        # the other layers.
        torch.manual_seed(0)
        return TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            **shared_kwargs,
            **format_kwargs,
        )

    block_sbhd = _build_layer(attn_input_format="sbhd")
    block_bshd = _build_layer(attn_input_format="bshd")
    block_thd = _build_layer(attn_input_format="thd", self_attn_mask_type="padding_causal")

    named_triples = zip(
        block_bshd.named_parameters(),
        block_sbhd.named_parameters(),
        block_thd.named_parameters(),
    )
    for (name_bshd, w_bshd), (name_sbhd, w_sbhd), (name_thd, w_thd) in named_triples:
        weights_match = torch.eq(w_bshd, w_sbhd) & torch.eq(w_bshd, w_thd)
        assert torch.all(weights_match), f"{name_bshd}, {name_sbhd} and {name_thd} not identical"

    # One random tensor, re-laid-out for each format.
    x_sbhd = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    x_bshd = x_sbhd.transpose(0, 1).contiguous()
    x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous()
    # Cumulative sequence lengths: multiples of seq_len from 0 to bs * seq_len.
    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len

    # Re-seed before every forward pass to make sure forward is also identical
    # (just in case some module decides to act fancy).
    torch.manual_seed(0)
    y_sbhd = block_sbhd(x_sbhd)

    torch.manual_seed(0)
    y_bshd = block_bshd(x_bshd)

    # Check that results match
    torch.testing.assert_close(
        y_bshd,
        y_sbhd.transpose(0, 1).contiguous(),
    )

    # THD is not supported in float32 and on GPUs older than Ampere, skip the test here
    if dtype == torch.float32 or not sm_80plus:
        return

    torch.manual_seed(0)
    y_thd = block_thd(
        x_thd,
        cu_seqlens_q=x_thd_cumsum,
        cu_seqlens_kv=x_thd_cumsum,
        max_seqlen_q=config.seq_len,
        max_seqlen_kv=config.seq_len,
    )

    torch.testing.assert_close(
        y_bshd,
        y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
    )

2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model_key", model_configs_inference.keys())
@pytest.mark.parametrize("use_RoPE", all_boolean)
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
    """Check that token-by-token decoding with a KV cache matches a full forward pass.

    The attention backend is selected through the `NVTE_FLASH_ATTN` and
    `NVTE_FUSED_ATTN` environment variables: both are set to "0" and the one
    matching `backend` ("FlashAttention" or "FusedAttention") is set to "1".
    The previous values of both variables are restored when the test ends so
    that the selection does not leak into tests that run later in the same
    process.
    """
    backend_env = {
        "NVTE_FLASH_ATTN": "1" if backend == "FlashAttention" else "0",
        "NVTE_FUSED_ATTN": "1" if backend == "FusedAttention" else "0",
    }
    saved_env = {name: os.environ.get(name) for name in backend_env}
    os.environ.update(backend_env)
    try:
        _check_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module)
    finally:
        # Put the environment back exactly as it was, including "unset".
        for name, previous in saved_env.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous


def _check_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module):
    """Compare full-sequence output with incrementally generated output.

    Builds a `TransformerLayer` (when `module == "TransformerLayer"`) or a
    `MultiheadAttention` in eval mode, runs it once on the whole sequence and
    then once per position with `InferenceParams`, and asserts that both
    outputs agree within a dtype-dependent tolerance.
    """
    config = model_configs_inference[model_key]

    S = config.seq_len
    B = bs
    H = config.num_attention_heads
    D = config.hidden_size
    head_size = config.embed
    layer_number = 1

    # Limits the max size of KV-cache
    B_max = B
    S_max = S + 2

    if module == "TransformerLayer":
        model = TransformerLayer(
            hidden_size=D,
            ffn_hidden_size=4 * D,
            num_attention_heads=H,
            attn_input_format=input_format,
            self_attn_mask_type="causal",
            enc_dec_attn_mask_type="causal",
            layer_number=layer_number,
            attention_dropout=0.0,
            params_dtype=dtype,
            device="cuda",
        ).eval()
    else:
        model = (
            MultiheadAttention(
                hidden_size=D,
                num_attention_heads=H,
                qkv_format=input_format,
                layer_number=layer_number,
                attention_dropout=0.0,
                attn_mask_type="causal",
                params_dtype=dtype,
            )
            .cuda()
            .eval()
        )

    inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max)
    rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
    rotary_pos_emb = rotary_freqs if use_RoPE else None

    # Named `hidden_states` rather than `input` to avoid shadowing the builtin.
    hidden_states = torch.randn((S, B, D), dtype=dtype, device="cuda")
    if input_format == "bshd":
        hidden_states = hidden_states.transpose(0, 1).contiguous()

    incremental_output = torch.zeros_like(hidden_states)

    # Generate output for the entire sequence
    full_output = model(hidden_states=hidden_states, rotary_pos_emb=rotary_pos_emb)

    # Incrementally generate outputs using KV-cache
    for i in range(S):
        if input_format == "sbhd":
            incremental_input = hidden_states[i].view(1, B, D)
        else:
            incremental_input = hidden_states[:, i, :].view(B, 1, D)

        line_output = model(
            hidden_states=incremental_input,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
        )

        # Advance the cache position for the next step.
        inference_params.sequence_len_offset += 1

        if input_format == "sbhd":
            incremental_output[i] = line_output.view(B, D)
        else:
            incremental_output[:, i, :] = line_output.view(B, D)

    if module == "TransformerLayer":
        atol = {
            torch.float32: 5e-3,
            torch.half: 5e-3,
            torch.bfloat16: 5e-2,
        }
    else:
        atol = {
            torch.float32: 1e-3,
            torch.half: 1e-3,
            torch.bfloat16: 1e-2,
        }

    # Check if the fully generated output matches the one generated incrementally
    assert_allclose(full_output, incremental_output, atol[dtype])
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161


@pytest.mark.parametrize(
    "shape",
    [
        (1, 127, 128, 512),
        (8, 15, 128, 512),
        (8, 1027, 128, 512),
        (16, 10027, 128, 512),
    ],
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm(shape, dtype, layout, accumulate):
    """Check `general_grouped_gemm` against a loop of `general_gemm` calls.

    `shape` is `(z, m, k, n)`: `z` groups whose row counts are random and sum
    to `m`. Operands are built per layout ("TN", "NN", anything else is
    treated as "NT"), the reference result is computed group by group with
    `general_gemm`, and the grouped result must match it exactly
    (`rtol=0, atol=0`).
    """
    torch.manual_seed(0)
    z, m, k, n = shape

    # z - 1 sorted cut points in [0, m); the gaps between consecutive cut
    # points (with 0 and m as the outer bounds) are the group sizes.
    cuts = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist()
    m_splits = torch.tensor(cuts + [m]) - torch.tensor([0] + cuts)
    assert m_splits.sum() == m and len(m_splits) == z
    m_splits = m_splits.tolist()

    def _rand(rows, cols):
        return torch.randn(rows, cols, dtype=dtype, device="cuda")

    def _rand_groups(cols):
        # One (m, cols) tensor split row-wise into the z groups.
        return list(torch.split(_rand(m, cols), m_splits))

    if layout in ("TN", "NN"):
        is_forward = layout == "TN"
        A = [_rand(n, k) for _ in range(z)]  # weight
        # TN: B is the input and out is the output.
        # NN: B is grad_output and out is dgrad.
        B = _rand_groups(k if is_forward else n)
        out = [_rand(m, n if is_forward else k)]
        out_ref = [chunk.clone() for chunk in torch.split(out[0], m_splits)]
        grad = not is_forward
        single_output = True
    else:  # layout == "NT"
        A = _rand_groups(k)  # input
        B = _rand_groups(n)  # grad_output
        out = [_rand(n, k) for _ in range(z)]  # wgrad
        out_ref = [o.clone() for o in out]
        grad = True
        single_output = False

    # Reference: one GEMM per group.
    for a_group, b_group, ref_group in zip(A, B, out_ref):
        general_gemm(
            a_group,
            b_group,
            get_workspace(),
            dtype,
            grad=grad,
            accumulate=accumulate,
            layout=layout,
            out=ref_group,
        )
    if single_output:
        out_ref = [torch.cat(out_ref)]

    general_grouped_gemm(
        A,
        B,
        out,
        dtype,
        get_multi_stream_cublas_workspace(),
        m_splits=m_splits,
        grad=grad,
        accumulate=accumulate,
        layout=layout,
        single_output=single_output,
    )

    # should be bit-wise match
    for result, expected in zip(out, out_ref):
        torch.testing.assert_close(result, expected, rtol=0, atol=0)


@pytest.mark.parametrize(
    "shape",
    [
        (1, 128, 128, 512),
        (8, 1024, 128, 512),
        (16, 4096, 128, 512),
    ],
)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("accumulate", [False, True])
def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
    """Check `general_grouped_gemm` on FP8 inputs against per-group `general_gemm`.

    `shape` is `(z, m, k, n)`; the `m` rows are split into `z` equal groups.
    bfloat16 operands are quantized with `Float8Quantizer`, the reference is
    computed group by group, and the grouped result must match it exactly.

    NOTE(review): `fp8_dtype` is not read anywhere in this body; both
    quantizer lists are created with `tex.DType.kFloat8E4M3`, so both
    parametrized values currently run the same computation. Confirm whether
    one operand was meant to use `fp8_dtype`.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)

    z, m, k, n = shape
    m_splits = [m // z for _ in range(z)]

    dtype = torch.bfloat16
    A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
    B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # input
    out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # output
    out_ref = [chunk.clone() for chunk in out]

    # fp8 should be robust enough to this fake scale
    scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze()
    amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda")

    def _new_quantizer():
        # Every quantizer gets its own copy of the scale and amax tensors.
        return Float8Quantizer(
            scale.clone(),
            amax.clone(),
            tex.DType.kFloat8E4M3,
        )

    a_quantizers = [_new_quantizer() for _ in range(z)]
    b_quantizers = [_new_quantizer() for _ in range(z)]

    A_fp8 = []
    B_fp8 = []
    for quantize_a, quantize_b, a_group, b_group in zip(a_quantizers, b_quantizers, A, B):
        A_fp8.append(quantize_a(a_group))
        B_fp8.append(quantize_b(b_group))

    # baseline
    for a_q, b_q, ref_group in zip(A_fp8, B_fp8, out_ref):
        general_gemm(
            a_q,
            b_q,
            get_workspace(),
            dtype,
            out=ref_group,
            accumulate=accumulate,
        )

    general_grouped_gemm(
        A_fp8,
        B_fp8,
        out,
        dtype,
        get_multi_stream_cublas_workspace(),
        m_splits=m_splits,
        accumulate=accumulate,
    )

    # should be bit-wise match
    for result, expected in zip(out, out_ref):
        torch.testing.assert_close(result, expected, rtol=0, atol=0)
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340


def test_noncontiguous():
    """Check that LayerNorm, RMSNorm and Linear accept a non-contiguous input.

    For each module type two instances with identical parameters are created.
    One is run on a transposed (non-contiguous) tensor, the other on its
    contiguous copy, and the collected outputs and parameter gradients are
    compared with `assert_allclose` at a tolerance of 1e-7.
    """

    def _make_twins(module_cls, ctor_args):
        # Two instances of the same module, the second carrying clones of the
        # first one's parameter data.
        first = module_cls(*ctor_args)
        second = module_cls(*ctor_args)
        for src, dst in zip(first.parameters(), second.parameters()):
            dst.data = src.data.clone()

        return first, second

    def _forward_backward(module, inp):
        # Forward + backward; returns the output, the input grad when there is
        # one, and the grads of all trainable parameters.
        result = module(inp)
        result.sum().backward()
        collected = [result]
        # NOTE(review): `a` and `b` below are derived from a leaf tensor (via
        # `.T` / `.contiguous()`), so `inp.grad` is presumably None here and
        # input gradients are not part of the comparison - confirm intent.
        if inp.grad is not None:
            collected.append(inp.grad)

        collected.extend(p.grad for p in module.parameters() if p.requires_grad)
        return collected

    a = torch.randn((128, 256), device="cuda", requires_grad=True).T
    assert not a.is_contiguous(), "The test is supposed to test noncontiguous input."

    b = a.contiguous()

    # LayerNorm, RMSNorm, then GEMM (Linear), each with its constructor args.
    cases = (
        (LayerNorm, [128]),
        (RMSNorm, [128]),
        (Linear, [128, 128]),
    )
    for module_cls, ctor_args in cases:
        mod_noncontig, mod_contig = _make_twins(module_cls, ctor_args)
        outT = _forward_backward(mod_noncontig, a)
        out = _forward_backward(mod_contig, b)

        assert_allclose(out, outT, 1e-7)