# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
import math
6
import os
7
from typing import Dict, List, Tuple, Optional
8
import pytest
9
import random
10
11
12
13

import torch
import torch.nn as nn
from torch.nn import Parameter
yuguo's avatar
yuguo committed
14
from torch.utils.cpp_extension import IS_HIP_EXTENSION
15

16
17
18
19
20
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    fp8_autocast,
    fp8_model_init,
)
21
22
23
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
24
    attention_mask_func,
25
    is_bf16_compatible,
26
27
)
from transformer_engine.pytorch import (
28
29
30
31
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
32
    GroupedLinear,
33
34
35
36
    MultiheadAttention,
    RMSNorm,
    TransformerLayer,
    LayerNorm,
37
38
    Fp8Padding,
    Fp8Unpadding,
39
)
40
from transformer_engine.pytorch import torch_version
41
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
42
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
43
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
44
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
45
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
46
from transformer_engine.pytorch.utils import get_device_compute_capability
47
from transformer_engine.common import recipe
48
import transformer_engine_torch as tex
49
from utils import ModelConfig, reset_rng_states, get_available_attention_backends
50

51
# Probe which FP8 flavours this device supports; FP8 tests are only run where
# the corresponding recipe is available.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()

# True when the device's compute capability is 8.0 or newer.
sm_80plus = get_device_compute_capability() >= (8, 0)

seed = 1234

# Reset RNG states via the shared test helper before anything draws random numbers.
reset_rng_states()

# Allow up to 16 dynamo recompilations. torch >= 2.7 exposes this limit as
# `recompile_limit`; older releases call it `cache_size_limit`.
if torch_version() >= (2, 7, 0):
    torch._dynamo.config.recompile_limit = 16
else:
    torch._dynamo.config.cache_size_limit = 16
66
67
68


model_configs = {
    "small": ModelConfig(1, 128, 8, 16, num_layers=4),
    "126m": ModelConfig(1, 2048, 12, 64, num_layers=12),
}
model_configs_inference = {
    "126m": ModelConfig(1, 256, 12, 64, num_layers=12),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]

# bf16 requires sm_80 or higher, so it is only added when the device supports it.
param_types = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_compatible() else []
)

batch_sizes = [1, 2]

all_boolean = [True, False]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]

all_normalizations = ["LayerNorm", "RMSNorm"]

mask_types = ["causal", "no_mask"]

# Non-zero value enables running these tests under the nvdlfw_inspect debug API.
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))
94
95
96
97
98
99
100
101
102
103
104
105
106

if NVTE_TEST_NVINSPECT_ENABLED:
    import nvdlfw_inspect.api as debug_api

    # With debug=True every layer is expected to produce the same numerics as
    # without it. The config referenced below is expected to enable a dummy
    # feature, because debug mode can switch itself off when no feature is active.
    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )

107
108
109
110
111
112
113
114
115

# FP8 recipes to exercise, in order: MXFP8 block scaling, FP8 block scaling,
# FP8 current scaling, delayed scaling -- each only if the device supports it.
fp8_recipes = [
    *([recipe.MXFP8BlockScaling()] if mxfp8_available else []),
    *([recipe.Float8BlockScaling()] if fp8_block_scaling_available else []),
    *([recipe.Float8CurrentScaling(), recipe.DelayedScaling()] if fp8_available else []),
]
116

117

118
119
120
def is_fused_attn_available(
    config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
):
    """Report whether the ``F16_arbitrary_seqlen`` fused-attention backend can run ``config``.

    Queries ``get_available_attention_backends`` for the given dtype, QKV layout and
    training mode, and looks for the backend in the third element of its result.
    """
    _, _, fused_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
        is_training=is_training,
    )
    wanted_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
    return wanted_backend in fused_backends
128

129

130
131
132
133
def get_causal_attn_mask(sq: int, device="cuda") -> torch.Tensor:
    """Return a ``[sq, sq]`` boolean causal mask.

    Entry ``(i, j)`` is True exactly when ``j > i`` (strictly above the diagonal),
    i.e. for key positions after query position ``i``.

    Args:
        sq: Sequence length.
        device: Device to allocate the mask on. Defaults to ``"cuda"`` as before;
            pass e.g. ``"cpu"`` to build the mask without a GPU.
    """
    return torch.triu(torch.ones(sq, sq, device=device), diagonal=1).bool()


134
135
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.

    Returns:
        A new ``{"rtol": ..., "atol": ...}`` dict on every call, so callers may
        modify the result in place.

    Raises:
        ValueError: for any dtype other than float32, float16 or bfloat16.
    """
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    raise ValueError(f"Unsupported dtype ({dtype})")


def assert_allclose(
    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float = None, rtol: float = None
) -> None:
    """Ensures two lists are equal.

    Each pair ``(l1[i], l2[i])`` is compared with ``torch.allclose``. Tolerances
    default to ``dtype_tols`` of the reference tensor ``l2[i]``; ``atol``/``rtol``
    override them when given.

    Raises:
        AssertionError: if the lists differ in length, or any pair is not close.
            The message locates the largest violation. If the mismatch involves
            only non-finite values (e.g. NaN), it says so instead.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        tols = dtype_tols(t2.dtype)
        if rtol is not None:
            tols["rtol"] = rtol
        if atol is not None:
            tols["atol"] = atol
        if torch.allclose(t1, t2, **tols):
            continue

        diff = torch.abs(t1 - t2)
        tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
        exceed_mask = diff > tol
        if exceed_mask.any():
            indices = torch.nonzero(exceed_mask, as_tuple=True)
            max_diff = diff[exceed_mask].max()
            max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
            max_location = [idx[max_idx].item() for idx in indices]
            msg = (
                f"Outputs not close enough in tensor at idx={i}. "
                f"Maximum difference at location {max_location} "
                f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
                f"(diff {max_diff.item()})."
            )
        else:
            # allclose can fail with no element where `diff > tol` holds, e.g. when
            # NaNs are present (comparisons with NaN are False) or for opposite
            # infinities. Without this branch `msg` would be unbound here.
            msg = (
                f"Outputs not close enough in tensor at idx={i} "
                f"(mismatch in non-finite values such as NaN/inf)."
            )
        raise AssertionError(msg)
177
178


179
180
181
182
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture: reset FP8 global state after every test in this module."""
    # Nothing to set up; the code after the yield is the teardown.
    yield None
    FP8GlobalStateManager.reset()
183
184


185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
class TorchScaledMaskedSoftmax(nn.Module):
    """Reference softmax over the last dim with optional scaling and masking.

    Computation happens in float32. The input is multiplied by ``scale`` when
    given, masked with ``attention_mask_func`` when a mask is given, soft-maxed,
    and cast back to the input's dtype.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        orig_dtype = inp.dtype
        logits = inp.float()
        if scale is not None:
            logits = logits * scale
        if mask is not None:
            logits = attention_mask_func(logits, mask)
        return torch.nn.functional.softmax(logits, dim=-1).to(orig_dtype)


class TorchDotProductAttention(torch.nn.Module):
    """Reference (unfused) scaled dot-product attention in plain PyTorch.

    Scores are scaled by ``1 / sqrt(kv_channels)``. They are then passed with the
    optional mask through ``TorchScaledMaskedSoftmax``, and dropout is applied to
    the resulting probabilities.
    """

    def __init__(
        self,
        kv_channels: int,
        attention_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.norm_factor = math.sqrt(kv_channels)
        self.scale_mask_softmax = TorchScaledMaskedSoftmax()
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention.

        Args:
            query_layer: ``[sq, b, np, hn]`` queries.
            key_layer: ``[sk, b, np, hn]`` keys.
            value_layer: ``[sk, b, np, hn]`` values.
            attention_mask: Optional mask handed to the softmax together with the
                ``[b, np, sq, sk]`` scores.

        Returns:
            ``[sq, b, np * hn]`` context tensor.
        """
        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # Preallocated baddbmm output: [b * np, sq, sk]. Allocate it on the inputs'
        # device rather than torch.cuda.current_device(). That way the reference
        # also works for inputs on a non-current GPU or on the CPU. With beta=0.0
        # baddbmm ignores its (uninitialized) contents.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
        attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        context_layer = context_layer.view(seqlen, batch_size, -1)

        return context_layer

292

293
class TorchLayerNorm(nn.Module):
    """Reference LayerNorm over the last dim, computed in float32.

    With ``zero_centered_gamma`` the learnable weight stores ``gamma - 1``, so the
    effective scale is ``1 + weight``. Otherwise the weight is the scale itself.
    """

    def __init__(self, in_features: int, eps: float, zero_centered_gamma: bool):
        super().__init__()
        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        # Start from an effective scale of 1. A zero-centered weight is an offset
        # from 1, so it starts at 0; a plain weight is the scale and starts at 1.
        initial_value = torch.zeros(in_features) if zero_centered_gamma else torch.ones(in_features)
        # Assigning nn.Parameter attributes registers them (in this order: weight,
        # bias); no explicit register_parameter call is needed.
        self.weight = nn.Parameter(initial_value)
        self.bias = nn.Parameter(torch.zeros(in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight if not self.zero_centered_gamma else 1 + self.weight
        w = w.to(torch.float32)
        b = self.bias.to(torch.float32)
        inp = x.to(torch.float32)
        out = torch.nn.functional.layer_norm(
            inp, (self.in_features,), weight=w, bias=b, eps=self.eps
        )
        return out.to(x.dtype)

316

317
318
# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
    """Reference RMSNorm: ``w * x / sqrt(sum(x**2, -1) / in_features + eps)``.

    The reduction runs in float32. With ``zero_centered_gamma`` the learnable
    weight stores ``gamma - 1``, so the effective scale is ``1 + weight``.
    """

    def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
        super().__init__()

        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        # Start from an effective scale of 1. A zero-centered weight is an offset
        # from 1, so it starts at 0; a plain weight is the scale and starts at 1.
        initial_value = torch.zeros(in_features) if zero_centered_gamma else torch.ones(in_features)
        # Assigning an nn.Parameter attribute registers it; no explicit
        # register_parameter call is needed.
        self.weight = nn.Parameter(initial_value)

    def forward(self, x):
        norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
        d_x = self.in_features

        rms_x2 = norm_x2 / d_x + self.eps
        r_rms_x = rms_x2 ** (-1.0 / 2)
        x_normed = x * r_rms_x

        w = self.weight.float()
        if self.zero_centered_gamma:
            w = 1 + w
        return (w * x_normed).to(x.dtype)
342

343

344
class TorchLayerNormLinear(nn.Module):
    """Reference normalization followed by a linear projection.

    ``normalization`` picks ``TorchLayerNorm`` ("LayerNorm") or ``TorchRMSNorm``
    ("RMSNorm"); any other value raises ``RuntimeError``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float,
        normalization: str = "LayerNorm",
        zero_centered_gamma: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        if normalization not in ("LayerNorm", "RMSNorm"):
            raise RuntimeError("Unsupported normalization")
        norm_cls = TorchLayerNorm if normalization == "LayerNorm" else TorchRMSNorm
        self.layernorm = norm_cls(in_features, eps=eps, zero_centered_gamma=zero_centered_gamma)
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.layernorm(x)
        return self.linear(normed)


class TorchMHA(nn.Module):
    """Reference self-attention built on ``torch.nn.MultiheadAttention``.

    Inputs are sequence-first (``batch_first=False``). Projections have biases,
    and attention dropout is 0.1 (applied by the module only in training mode).
    """

    def __init__(self, hidden_size: int, num_attention_heads: int):
        super().__init__()
        self.mhsa = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            dropout=0.1,
            bias=True,
            batch_first=False,
        )

    def forward(self, x, attention_mask=None):
        """Self-attend over ``x`` (query = key = value); return only the attention output."""
        result = self.mhsa(x, x, x, attn_mask=attention_mask, need_weights=False)
        return result[0] if isinstance(result, tuple) else result

389

390
391
392
class TorchQuickGELU(nn.Module):
    """Quick GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(1.702 * input)
        return input * gate
393

394

395
396
397
398
class TorchSquaredRELU(nn.Module):
    """Squared ReLU: ``x * x`` where ``x > 0``, zero elsewhere."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        positive = input > 0
        # Keep the (mask * x) * x evaluation order: it determines how extreme
        # values (e.g. -inf, overflow) propagate.
        return positive * input * input

399

400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
class TorchGroupedLinearWithPadding(nn.Module):
    """``GroupedLinear`` wrapped in ``Fp8Padding``/``Fp8Unpadding`` when ``fp8`` is set.

    With ``fp8`` the input and its split sizes ``m_splits`` are padded before the
    grouped linear, and the output is unpadded back to the caller's split sizes.
    Without ``fp8`` the input goes straight to ``GroupedLinear``.
    """

    def __init__(
        self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8
    ) -> None:
        super().__init__()

        self.padding = Fp8Padding(num_gemms)
        self.linear_fn = GroupedLinear(
            num_gemms,
            in_features,
            out_features,
            bias=bias,
            params_dtype=params_dtype,
            parallel_mode=parallel_mode,
            device="cuda",
        )
        self.unpadding = Fp8Unpadding(num_gemms)

        self.fp8 = fp8

    def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor:
        if not self.fp8:
            return self.linear_fn(inp, m_splits)

        padded_inp, padded_splits = self.padding(inp, m_splits)
        padded_out = self.linear_fn(padded_inp, padded_splits)
        # Strip padding using the caller's original split sizes.
        return self.unpadding(padded_out, m_splits)


434
435
436
437
438
439
440
441
442
# Activation name -> module applied by the reference MLP. For the gated variants
# ("geglu", "reglu", "swiglu"), TorchGLU applies this module to the first half
# of its input.
_supported_act = dict(
    geglu=nn.GELU(approximate="tanh"),
    gelu=nn.GELU(approximate="tanh"),
    reglu=nn.ReLU(),
    relu=nn.ReLU(),
    swiglu=nn.SiLU(),
    qgelu=TorchQuickGELU(),
    srelu=TorchSquaredRELU(),
)
443

444

445
446
447
448
449
450
451
class TorchGLU(nn.Module):
    """Gated linear unit: ``act(a) * b``, where ``a``/``b`` are the halves of the last dim."""

    def __init__(self, activation: str):
        super().__init__()
        self.act = _supported_act[activation]

    def forward(self, x):
        half = x.size(-1) // 2
        gate = self.act(x[..., :half])
        return gate * x[..., half:]
456

457

458
class TorchLayerNormMLP(nn.Module):
    """Reference ``norm -> fc1 -> activation -> fc2`` MLP.

    Activations whose name contains "glu" are gated (``TorchGLU``), so ``fc1``
    produces ``2 * ffn_hidden_size`` features that are gated down to
    ``ffn_hidden_size``. ``normalization`` must be "LayerNorm" or "RMSNorm";
    anything else raises ``RuntimeError``.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        eps: float = 1e-5,
        activation="gelu",
        normalization: str = "LayerNorm",
        bias: bool = True,
    ):
        super().__init__()
        if normalization not in ("LayerNorm", "RMSNorm"):
            raise RuntimeError("Unsupported normalization")
        norm_cls = TorchLayerNorm if normalization == "LayerNorm" else TorchRMSNorm
        self.ln = norm_cls(hidden_size, eps=eps, zero_centered_gamma=False)

        is_gated = "glu" in activation
        fc1_output_features = 2 * ffn_hidden_size if is_gated else ffn_hidden_size
        self.gelu = TorchGLU(activation) if is_gated else _supported_act[activation]

        self.fc1 = nn.Linear(hidden_size, fc1_output_features, bias=bias)
        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(self.ln(x))))
488
489
490


class TorchGPT(nn.Module):
    """Reference GPT block made of plain PyTorch modules.

    Attention reads ``LayerNorm(x)``; the MLP (``TorchLayerNormMLP``) applies its
    own normalization. Sequentially, the MLP runs on the attention residual.
    With ``parallel_attention_mlp`` both branches read the block input and their
    sum forms a single residual. Residual branches use dropout with p=0.1 when
    training.
    """

    def __init__(
        self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool
    ):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)
        self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
        self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
        self.parallel_attention_mlp = parallel_attention_mlp

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        def residual_dropout(t):
            return nn.functional.dropout(t, p=0.1, training=self.training)

        attn_out = self.causal_attn(self.ln(x), attention_mask)
        if self.parallel_attention_mlp:
            return x + residual_dropout(attn_out + self.ln_mlp(x))

        hidden = x + residual_dropout(attn_out)
        return hidden + residual_dropout(self.ln_mlp(hidden))


517
518
519
def _test_e2e_selective_recompute(
    bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False
):
    """Run one forward/backward pass of a TransformerLayer and collect the results.

    ``recompute`` is passed as ``checkpoint_core_attention``. RNG and FP8 global
    state are reset first, so two calls differing only in ``recompute`` build the
    same weights and inputs.

    Returns:
        ``[output, input_grad, *grads]`` for every parameter requiring grad.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    weight_init = init_method_normal(sigma)
    output_weight_init = scaled_init_method_normal(sigma, config.num_layers)

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            layernorm_epsilon=config.eps,
            init_method=weight_init,
            output_layer_init_method=output_weight_init,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.max_seqlen_q)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        out = block(
            hidden_states,
            attention_mask=causal_mask,
            checkpoint_core_attention=recompute,
        )
    out.sum().backward()
    torch.cuda.synchronize()

    param_grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [out, hidden_states.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
    """Selective (core-attention) recompute must not change outputs or gradients."""
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")

    config = model_configs[model]
    run_args = (bs, dtype, config, fp8, recipe, fp8_model_params)
    outputs = _test_e2e_selective_recompute(*run_args, recompute=False)
    outputs_recompute = _test_e2e_selective_recompute(*run_args, recompute=True)

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols["atol"] = 1e-4
    if fp8 or fp8_model_params:
        tols.update(rtol=0.125, atol=0.0675)

    for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
        torch.testing.assert_close(test, ref, msg=f"Mismatch in tensor {i}", **tols)
604
605


606
def _test_e2e_full_recompute(
    bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False, use_reentrant=True
):
    """Run one forward/backward pass of a TransformerLayer, optionally under ``te_checkpoint``.

    With ``recompute`` the whole layer runs through ``te_checkpoint`` (full activation
    recompute) using ``use_reentrant``. The input tensor requires grad only when
    ``use_reentrant`` is set. Reentrant checkpointing presumably needs a grad-requiring
    input, as with ``torch.utils.checkpoint``.

    Returns:
        ``(outputs, names)``. ``outputs`` is the output, then the input grad (reentrant
        only), then each trainable parameter's grad; ``names`` labels each entry.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    weight_init = init_method_normal(sigma)
    output_weight_init = scaled_init_method_normal(sigma, config.num_layers)

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            layernorm_epsilon=config.eps,
            init_method=weight_init,
            output_layer_init_method=output_weight_init,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=use_reentrant,
    )
    if use_reentrant:
        hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.max_seqlen_q)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if not recompute:
            out = block(
                hidden_states,
                attention_mask=causal_mask,
                checkpoint_core_attention=False,
            )
        else:
            out = te_checkpoint(
                block,
                hidden_states,
                attention_mask=causal_mask,
                checkpoint_core_attention=False,
                distribute_saved_activations=False,
                tp_group=None,
                use_reentrant=use_reentrant,
            )
    out.sum().backward()
    torch.cuda.synchronize()

    outputs, names = [out], ["output"]
    if use_reentrant:
        outputs.append(hidden_states.grad)
        names.append("input")
    for param_name, param in block.named_parameters():
        if param.requires_grad:
            outputs.append(param.grad)
            names.append(param_name)

    return outputs, names
676
677
678
679


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_reentrant", all_boolean)
def test_gpt_full_activation_recompute(
    dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
):
    """Full activation recompute must reproduce the outputs and gradients of a plain run."""
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")

    config = model_configs[model]

    # Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion, so
    # the fusion is disabled for those runs. The previous value is restored in a
    # `finally`: the override must not leak into other tests if a run raises, and
    # a value set by the user must not be deleted.
    fusion_env_var = "NVTE_BIAS_GELU_NVFUSION"
    saved_fusion_flag = os.environ.get(fusion_env_var)
    if not use_reentrant:
        os.environ[fusion_env_var] = "0"
    try:
        outputs, _ = _test_e2e_full_recompute(
            bs,
            dtype,
            config,
            fp8,
            recipe,
            fp8_model_params,
            recompute=False,
            use_reentrant=use_reentrant,
        )
        outputs_recompute, _ = _test_e2e_full_recompute(
            bs,
            dtype,
            config,
            fp8,
            recipe,
            fp8_model_params,
            recompute=True,
            use_reentrant=use_reentrant,
        )
    finally:
        if not use_reentrant:
            if saved_fusion_flag is None:
                os.environ.pop(fusion_env_var, None)
            else:
                os.environ[fusion_env_var] = saved_fusion_flag

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols["atol"] = 1e-3
    if fp8 or fp8_model_params:
        tols.update(dict(rtol=0.125, atol=0.0675))
    for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )
735
736
737
738
739
740


def _test_e2e_checkpointing_get_model(config, dtype):
    """Build the TransformerLayer used by the checkpoint save/load test.

    Weights use normal init with std 0.023 (output layers scaled by
    ``config.num_layers``). They are drawn from the current RNG state.
    """
    sigma = 0.023
    weight_init = init_method_normal(sigma)
    output_weight_init = scaled_init_method_normal(sigma, config.num_layers)

    layer = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=weight_init,
        output_layer_init_method=output_weight_init,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
        device="cuda",
    )
    return layer


def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path="checkpoint.pt"):
    """Train a TransformerLayer for ``steps`` iterations, optionally checkpointing halfway.

    Runs ``steps // 2`` forward/backward iterations. If ``checkpoint`` is set, it then:
      * saves the state dict to ``path``;
      * rebuilds the model from scratch and reloads the state dict;
      * restores the accumulated gradients and the CPU/CUDA RNG state.
    It then runs another ``steps // 2`` iterations. Gradients are never zeroed,
    so they accumulate over all iterations. ``path`` is removed once it has been
    loaded, even if saving or loading fails.

    Returns:
        ``[last_output, input_grad, *param_grads]``.

    Raises:
        ValueError: if ``steps < 2`` (no iteration would run, leaving no output).
    """
    if steps < 2:
        raise ValueError(f"steps must be at least 2, got {steps}")

    reset_rng_states()

    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()

    block = _test_e2e_checkpointing_get_model(config, dtype)

    for _ in range(steps // 2):
        te_out = block(
            te_inp_hidden_states,
            None,
        )
        loss = te_out.sum()
        loss.backward()

    if checkpoint:
        # This process is necessary so that we can start afresh with
        # a new model while erasing all internal state to ensure that
        # loading from a checkpoint gives bitwise identical results.
        # Since gradients are being accumulated, it is important to
        # restore them post loading the checkpoint.
        try:
            torch.save(block.state_dict(), path)

            param_grads = [p.grad.clone() for p in block.parameters() if p.requires_grad]

            # Rebuilding the model draws fresh random weights; save the RNG state
            # so it can be restored afterwards.
            _cpu_rng_state = torch.get_rng_state()
            _cuda_rng_state = torch.cuda.get_rng_state()

            del block
            block = _test_e2e_checkpointing_get_model(config, dtype)
            # The file was just written by torch.save in this process, so full
            # unpickling (weights_only=False) is acceptable here.
            block.load_state_dict(torch.load(path, weights_only=False))
        finally:
            # Never leave the checkpoint file behind, even if save/load failed.
            if os.path.exists(path):
                os.remove(path)

        torch.set_rng_state(_cpu_rng_state)
        torch.cuda.set_rng_state(_cuda_rng_state)

        # The rebuilt model has the same parameters in the same order, so restore
        # the saved gradients one by one (iterator instead of O(n) list.pop(0)).
        grad_iter = iter(param_grads)
        for p in block.parameters():
            if p.requires_grad:
                p.grad = next(grad_iter)

        assert next(grad_iter, None) is None, "Oops!"

    for _ in range(steps // 2):
        te_out = block(
            te_inp_hidden_states,
            None,
        )
        loss = te_out.sum()
        loss.backward()

    torch.cuda.synchronize()

    outputs = [te_out, te_inp_hidden_states.grad]
    outputs.extend(p.grad for p in block.parameters() if p.requires_grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
    """Saving and reloading mid-training must not change outputs or gradients."""
    config = model_configs[model]
    if not is_fused_attn_available(config, dtype):
        pytest.skip("No attention backend available.")

    reference = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
    reloaded = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(rtol=2e-2, atol=2e-3)
    for i, (ref, test) in enumerate(zip(reference, reloaded)):
        torch.testing.assert_close(test, ref, msg=f"Mismatch in tensor {i}", **tols)
849
850
851
852
853
854


def _test_e2e_gpt_accuracy(block, bs, dtype, config):
    """Run ``block`` forward/backward on a seeded random input with a causal mask.

    Returns:
        ``[output, input_grad, *grads]`` for every parameter of ``block`` requiring grad.
    """
    reset_rng_states()

    hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.max_seqlen_q)

    out = block(hidden_states, attention_mask=causal_mask)
    out.sum().backward()

    torch.cuda.synchronize()
    grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [out, hidden_states.grad] + grads


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
    """Compare a TE ``TransformerLayer`` against the ``TorchGPT`` reference.

    The reference's parameters are overwritten with clones of the TE layer's
    parameters. The output is then compared, followed by the input and
    parameter gradients, using dtype-dependent tolerances.
    """
    config = model_configs[model]
    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
        pytest.skip("No attention backend available.")

    te_gpt = TransformerLayer(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        num_attention_heads=config.num_heads,
        layernorm_epsilon=config.eps,
        attention_dropout=0.1,
        hidden_dropout=0.1,
        params_dtype=dtype,
        fuse_qkv_params=True,
        qkv_weight_interleaved=False,
        parallel_attention_mlp=parallel_attention_mlp,
        device="cuda",
    ).eval()

    torch_gpt = (
        TorchGPT(
            config.hidden_size,
            config.eps,
            config.num_heads,
            parallel_attention_mlp=parallel_attention_mlp,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_gpt.ln.weight = Parameter(
            te_gpt.self_attention.layernorm_qkv.layer_norm_weight.clone()
        )
        torch_gpt.ln.bias = Parameter(te_gpt.self_attention.layernorm_qkv.layer_norm_bias.clone())
        torch_gpt.causal_attn.mhsa.in_proj_weight = Parameter(
            te_gpt.self_attention.layernorm_qkv.weight.clone()
        )
        torch_gpt.causal_attn.mhsa.in_proj_bias = Parameter(
            te_gpt.self_attention.layernorm_qkv.bias.clone()
        )
        torch_gpt.causal_attn.mhsa.out_proj.weight = Parameter(
            te_gpt.self_attention.proj.weight.clone()
        )
        torch_gpt.causal_attn.mhsa.out_proj.bias = Parameter(
            te_gpt.self_attention.proj.bias.clone()
        )
        torch_gpt.ln_mlp.ln.weight = Parameter(te_gpt.layernorm_mlp.layer_norm_weight.clone())
        torch_gpt.ln_mlp.ln.bias = Parameter(te_gpt.layernorm_mlp.layer_norm_bias.clone())
        torch_gpt.ln_mlp.fc1.weight = Parameter(te_gpt.layernorm_mlp.fc1_weight.clone())
        torch_gpt.ln_mlp.fc1.bias = Parameter(te_gpt.layernorm_mlp.fc1_bias.clone())
        torch_gpt.ln_mlp.fc2.weight = Parameter(te_gpt.layernorm_mlp.fc2_weight.clone())
        torch_gpt.ln_mlp.fc2.bias = Parameter(te_gpt.layernorm_mlp.fc2_bias.clone())

    te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config)
    torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config)

    atol = {
        torch.float32: 5e-3,
        torch.half: 5e-2,
        torch.bfloat16: 1e-1,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

    # Check gradients, only for small model
    if model == "small":
        atol[torch.float32] = 5e-2
        rtol = {
            torch.float32: 1e-2,
            torch.half: 1e-2,
            torch.bfloat16: 1e-2,
        }
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
    """Run one forward/backward pass of an attention ``block`` on random input.

    Args:
        block: attention module, called as ``block(inp, **forward_kwargs)``.
        bs: batch size (second dimension of the input).
        dtype: dtype of the random input.
        config: model config supplying ``max_seqlen_q`` and ``hidden_size``.
        mask_type: ``"causal"`` builds a mask with ``get_causal_attn_mask``;
            any other value passes ``attention_mask=None``.
        te: when True, also forward ``attn_mask_type=mask_type`` (a TE-style
            keyword; the torch reference in this file is called with False).

    Returns:
        ``[output, input_grad, *param_grads]`` for parameters with
        ``requires_grad``, in ``block.parameters()`` order.
    """
    reset_rng_states()

    inp_hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
    inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) if mask_type == "causal" else None

    forward_kwargs = {}
    if te:
        forward_kwargs["attn_mask_type"] = mask_type
    forward_kwargs["attention_mask"] = inp_attn_mask

    out = block(inp_hidden_states, **forward_kwargs)
    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
    """Compare TE ``MultiheadAttention`` against the ``TorchMHA`` reference.

    The QKV and output-projection parameters are cloned from the TE module into
    the reference. The output is then compared, followed by the gradients.
    """
    config = model_configs[model]
    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
        pytest.skip("No attention backend available.")

    te_mha = MultiheadAttention(
        config.hidden_size,
        config.num_heads,
        fuse_qkv_params=True,
        params_dtype=dtype,
        qkv_weight_interleaved=False,
        input_layernorm=False,
        device="cuda",
    ).eval()

    torch_mha = (
        TorchMHA(
            config.hidden_size,
            config.num_heads,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_mha.mhsa.in_proj_weight = Parameter(te_mha.qkv.weight.clone())
        torch_mha.mhsa.in_proj_bias = Parameter(te_mha.qkv.bias.clone())
        torch_mha.mhsa.out_proj.weight = Parameter(te_mha.proj.weight.clone())
        torch_mha.mhsa.out_proj.bias = Parameter(te_mha.proj.bias.clone())

    te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type, te=True)
    torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type, te=False)

    # Check output.
    if dtype == torch.float32:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
    else:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

    # Check gradients, only for small model
    if model == "small":
        atol = {
            torch.float32: 5e-2,
            torch.half: 5e-2,
            torch.bfloat16: 5e-2,
        }
        rtol = {
            torch.float32: 1e-2,
            torch.half: 1e-2,
            torch.bfloat16: 1e-2,
        }
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, recipe=None):
    """Run one forward/backward pass of ``block`` and collect results.

    Args:
        block: module called as ``block(inp)``; if it returns a list or tuple,
            only the first element is used as the output.
        bs: batch size (second dimension of the input).
        dtype: dtype of the random input.
        config: model config supplying ``max_seqlen_q`` and ``hidden_size``.
        delay_wgrad_compute: if True, call ``block.backward_dw()`` after
            ``loss.backward()``.
        recipe: FP8 recipe passed to ``fp8_autocast``. ``None`` disables FP8
            and skips the ``FP8GlobalStateManager.reset()``.

    Returns:
        ``[output, input_grad, *param_grads]`` for parameters with
        ``requires_grad``. When a parameter carries a non-None ``main_grad``,
        that buffer is collected instead of ``grad``, and ``grad`` is asserted
        to be None.
    """
    reset_rng_states()
    fp8 = recipe is not None
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        out = block(inp_hidden_states)
        # Some modules return (output, extra...); only the output is checked.
        if isinstance(out, (list, tuple)):
            out = out[0]

    loss = out.sum()
    loss.backward()
    if delay_wgrad_compute:
        block.backward_dw()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            if getattr(p, "main_grad", None) is not None:
                outputs.append(p.main_grad)
                assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
            else:
                outputs.append(p.grad)
    return outputs


def _test_dpa_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of a dot-product-attention ``block``.

    The query has ``config.max_seqlen_q`` positions. Key and value have
    ``config.max_seqlen_kv`` positions, so both agree with the boolean mask
    of shape ``(max_seqlen_q, max_seqlen_kv)``. The mask is built with
    ``torch.triu(..., diagonal=1)``, i.e. True strictly above the diagonal.
    Tensors are drawn in the order query, key, value, so the RNG stream is the
    same as drawing all three in sequence.

    Returns:
        ``[output, query_grad, key_grad, value_grad]``.
    """
    reset_rng_states()

    mask = torch.triu(
        torch.ones(config.max_seqlen_q, config.max_seqlen_kv, dtype=torch.bool, device="cuda"),
        diagonal=1,
    )
    query, key, value = [
        torch.randn(
            (seqlen, bs, config.num_heads, config.kv_channels),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        for seqlen in (config.max_seqlen_q, config.max_seqlen_kv, config.max_seqlen_kv)
    ]

    query.retain_grad()
    key.retain_grad()
    value.retain_grad()

    out = block(query, key, value, attention_mask=mask)
    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()

    return [out, query.grad, key.grad, value.grad]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_dpa_accuracy(dtype, bs, model):
    """Compare TE ``DotProductAttention`` against ``TorchDotProductAttention``.

    Attention dropout is 0.0 on both sides. The output is compared with a
    dtype-dependent tolerance, and the query/key/value gradients with fixed
    tolerances.
    """
    config = model_configs[model]

    te_dpa = (
        DotProductAttention(
            config.num_heads,
            config.kv_channels,
            attention_dropout=0.0,  # disable dropout, FU uses rng differently
        )
        .to(dtype=dtype)
        .cuda()
    )

    torch_dpa = (
        TorchDotProductAttention(
            config.kv_channels,
            0.0,  # dropout
        )
        .to(dtype=dtype)
        .cuda()
    )

    te_outputs = _test_dpa_accuracy(te_dpa, bs, dtype, config)
    torch_outputs = _test_dpa_accuracy(torch_dpa, bs, dtype, config)

    # Check output.
    if dtype == torch.float32:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
    else:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

    # Check query/key/value gradients.
    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)


class TestReturnBiasModule(nn.Module):
    """Wrap a module so that callers always get a single output tensor.

    ``mod`` is constructed with ``**kwargs``, which must include
    ``return_bias`` and ``bias``. When ``return_bias`` is True, the wrapped
    module is expected to return ``(out, bias)``. The bias is added back to
    ``out`` only when ``bias`` is True; otherwise ``out`` is returned as-is.
    """

    # Helper, not a test class: without this, pytest tries to collect it
    # because of the ``Test`` prefix and warns that it has an ``__init__``.
    __test__ = False

    def __init__(self, mod, **kwargs):
        super().__init__()
        self.te_module = mod(**kwargs)
        self.return_bias = kwargs["return_bias"]
        self.bias = kwargs["bias"]

    def forward(self, x):
        if self.return_bias:
            out, bias = self.te_module(x)
            if self.bias:
                out = out + bias
            return out
        return self.te_module(x)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_linear_accuracy(dtype, bs, model, return_bias, bias):
    """Compare TE ``Linear`` (via ``TestReturnBiasModule``) with ``torch.nn.Linear``.

    The torch layer receives clones of the TE weight and, when ``bias`` is
    set, the TE bias. The output and all gradients are then compared.
    """
    config = model_configs[model]

    te_linear = TestReturnBiasModule(
        Linear,
        in_features=config.hidden_size,
        out_features=4 * config.hidden_size,
        params_dtype=dtype,
        return_bias=return_bias,
        bias=bias,
        device="cuda",
    )

    torch_linear = torch.nn.Linear(
        config.hidden_size,
        4 * config.hidden_size,
        bias=bias,
        device="cuda",
        dtype=dtype,
    )

    # Share params
    with torch.no_grad():
        torch_linear.weight = Parameter(te_linear.te_module.weight.clone())
        if bias:
            torch_linear.bias = Parameter(te_linear.te_module.bias.clone())

    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)

    # Check output and gradients.
    if model == "small":
        tolerance = 5e-3 if dtype == torch.float32 else 5e-2
        rtol = {
            torch.float32: 1.3e-6,
            torch.half: 1e-2,
            torch.bfloat16: 2e-2,
        }
        for te_output, torch_output in zip(te_outputs, torch_outputs):
            assert_allclose(te_output, torch_output, tolerance, rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_accumulation):
    """Check that ``Linear(delay_wgrad_compute=True)`` bit-matches the eager module.

    The reference module receives clones of the delayed module's parameters.
    With ``fuse_wgrad_accumulation``, both weights start from the same random
    float32 ``main_grad`` buffer. The delayed module runs ``backward_dw()``
    after backward, and every collected tensor must match with rtol=atol=0.
    """
    config = model_configs[model]

    te_linear_ref = Linear(
        config.hidden_size,
        4 * config.hidden_size,
        bias=bias,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=False,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()

    te_linear = Linear(
        config.hidden_size,
        4 * config.hidden_size,
        bias=bias,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=True,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()

    # Share params
    with torch.no_grad():
        te_linear_ref.weight = Parameter(te_linear.weight.clone())
        if bias:
            te_linear_ref.bias = Parameter(te_linear.bias.clone())
        if fuse_wgrad_accumulation:
            weight = te_linear.weight
            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
            te_linear_ref.weight.main_grad = weight.main_grad.clone()

    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True)
    te_outputs_ref = _test_granular_accuracy(
        te_linear_ref, bs, dtype, config, delay_wgrad_compute=False
    )

    # Should be bit-wise match
    for o, o_ref in zip(te_outputs, te_outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
def test_linear_accuracy_save_original_input(dtype, model, recipe):
    """Check that ``Linear(save_original_input=True)`` bit-matches the default.

    Runs with each FP8 recipe and with ``None`` (no FP8). DelayedScaling
    recipes are skipped, and so are FP8 runs whose ``max_seqlen_q`` is not a
    multiple of 16. Weights and the float32 ``main_grad`` buffer are shared
    between the two modules.
    """
    bs = 1
    fuse_wgrad_accumulation = True
    fp8_model_params = False
    fp8 = recipe is not None

    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        te_linear_ref = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            save_original_input=False,
        ).eval()

        te_linear = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            save_original_input=True,
        ).eval()

    # Share params
    with torch.no_grad():
        te_linear_ref.weight = Parameter(te_linear.weight.clone())
        if fuse_wgrad_accumulation:
            weight = te_linear.weight
            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
            te_linear_ref.weight.main_grad = weight.main_grad.clone()

    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
    te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)

    # Should be bit-wise match
    for o, o_ref in zip(te_outputs, te_outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
    """Compare TE ``RMSNorm`` against the ``TorchRMSNorm`` reference.

    The reference weight is a clone of the TE weight. The output is compared
    with a dtype-dependent atol; gradients use looser atol/rtol.
    """
    config = model_configs[model]

    te_rmsnorm = RMSNorm(
        config.hidden_size,
        eps=eps,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_rmsnorm = (
        TorchRMSNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_rmsnorm.weight = Parameter(te_rmsnorm.weight.clone())

    te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)

    atol = {
        torch.float32: 1e-7,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

    atol[torch.float32] = 2e-3
    rtol = {
        torch.float32: 1.3e-6,
        torch.half: 1e-3,
        torch.bfloat16: 1.6e-2,
    }
    # Check gradients
    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
    """Compare TE ``LayerNorm`` against the ``TorchLayerNorm`` reference.

    The reference weight and bias are clones of the TE parameters. The output
    is compared with a dtype-dependent atol; gradients use looser atol/rtol.
    """
    config = model_configs[model]

    te_layernorm = LayerNorm(
        config.hidden_size,
        eps=eps,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_layernorm = (
        TorchLayerNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_layernorm.weight = Parameter(te_layernorm.weight.clone())
        torch_layernorm.bias = Parameter(te_layernorm.bias.clone())

    te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)

    atol = {
        torch.float32: 1e-7,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

    rtol = {
        torch.float32: 1.3e-6,
        torch.half: 1e-3,
        torch.bfloat16: 1.6e-2,
    }
    atol[torch.float32] = 1e-4
    # Check gradients
    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_linear_accuracy(
    dtype, bs, model, normalization, zero_centered_gamma, return_bias, bias
):
    """Compare TE ``LayerNormLinear`` against the ``TorchLayerNormLinear`` reference.

    The norm weight, norm bias (skipped for RMSNorm), linear weight and
    optional linear bias are cloned into the reference. The output is compared
    first, then the gradients.
    """
    config = model_configs[model]

    te_ln_linear = TestReturnBiasModule(
        LayerNormLinear,
        in_features=config.hidden_size,
        out_features=4 * config.hidden_size,
        eps=config.eps,
        normalization=normalization,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        return_bias=return_bias,
        bias=bias,
        device="cuda",
    )

    torch_ln_linear = (
        TorchLayerNormLinear(
            config.hidden_size,
            4 * config.hidden_size,
            config.eps,
            normalization=normalization,
            zero_centered_gamma=zero_centered_gamma,
            bias=bias,
        )
        .to(dtype=dtype)
        .cuda()
    )

    # Share params
    with torch.no_grad():
        torch_ln_linear.layernorm.weight = Parameter(
            te_ln_linear.te_module.layer_norm_weight.clone()
        )
        if normalization != "RMSNorm":
            torch_ln_linear.layernorm.bias = Parameter(
                te_ln_linear.te_module.layer_norm_bias.clone()
            )
        torch_ln_linear.linear.weight = Parameter(te_ln_linear.te_module.weight.clone())
        if bias:
            torch_ln_linear.linear.bias = Parameter(te_ln_linear.te_module.bias.clone())

    te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)

    atol = {
        torch.float32: 2.5e-4,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }

    rtol = {
        torch.float32: 1e-3,
        torch.half: 4e-2,
        torch.bfloat16: 4e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype])

    if model == "small":
        atol = {
            torch.float32: 1e-3,
            torch.half: 5e-2,
            torch.bfloat16: 5e-2,
        }
        rtol = {
            torch.float32: 1e-3,
            torch.half: 4e-2,
            torch.bfloat16: 4e-2,
        }
        # Check gradients
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_linear_accuracy_delay_wgrad_compute(
    dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation
):
    """Check that ``LayerNormLinear(delay_wgrad_compute=True)`` bit-matches the eager module.

    Parameters (and, with ``fuse_wgrad_accumulation``, the float32
    ``main_grad`` buffer on the weight) are shared. Every collected tensor must
    match with rtol=atol=0.
    """
    config = model_configs[model]

    ln_linear_ref = LayerNormLinear(
        config.hidden_size,
        4 * config.hidden_size,
        config.eps,
        bias=bias,
        normalization=normalization,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
        delay_wgrad_compute=False,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()

    ln_linear = LayerNormLinear(
        config.hidden_size,
        4 * config.hidden_size,
        config.eps,
        bias=bias,
        normalization=normalization,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
        delay_wgrad_compute=True,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()

    # Share params
    with torch.no_grad():
        ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone())
        if normalization != "RMSNorm":
            ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone())
        ln_linear_ref.weight = Parameter(ln_linear.weight.clone())
        if bias:
            ln_linear_ref.bias = Parameter(ln_linear.bias.clone())
        if fuse_wgrad_accumulation:
            weight = ln_linear.weight
            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
            ln_linear_ref.weight.main_grad = weight.main_grad.clone()

    te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, delay_wgrad_compute=True)
    te_outputs_ref = _test_granular_accuracy(
        ln_linear_ref, bs, dtype, config, delay_wgrad_compute=False
    )

    # Should be bit-wise match
    for o, o_ref in zip(te_outputs, te_outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):
    """Compare TE ``LayerNormMLP`` against the ``TorchLayerNormMLP`` reference.

    The norm weight, norm bias (skipped for RMSNorm), FC1/FC2 weights and the
    optional FC biases are cloned into the reference. The output is compared
    first, then the gradients with looser fp16/bf16 atol.
    """
    config = model_configs[model]

    te_ln_mlp = TestReturnBiasModule(
        LayerNormMLP,
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        return_bias=return_bias,
        bias=bias,
        device="cuda",
    )

    torch_ln_mlp = (
        TorchLayerNormMLP(
            config.hidden_size,
            4 * config.hidden_size,
            activation=activation,
            normalization=normalization,
            bias=bias,
        )
        .to(dtype=dtype)
        .cuda()
    )

    # Share params
    with torch.no_grad():
        torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.te_module.layer_norm_weight.clone())
        if normalization != "RMSNorm":
            torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.te_module.layer_norm_bias.clone())
        torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.te_module.fc1_weight.clone())
        torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.te_module.fc2_weight.clone())
        if bias:
            torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.te_module.fc1_bias.clone())
            torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.te_module.fc2_bias.clone())

    te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)

    atol = {
        torch.float32: 2e-2,
        torch.half: 5e-2,
        torch.bfloat16: 5e-2,
    }

    rtol = {
        torch.float32: 1e-3,
        torch.half: 4e-2,
        torch.bfloat16: 4e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype])

    # Check gradients, only for small model
    rtol = {
        torch.float32: 1e-3,
        torch.half: 1e-2,
        torch.bfloat16: 4e-2,
    }
    atol[torch.half] = 2e-1
    atol[torch.bfloat16] = 2e-1
    if model == "small":
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", [2])
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
    dtype, bs, model, bias, fuse_wgrad_accumulation
):
    """Check that ``LayerNormMLP(delay_wgrad_compute=True)`` bit-matches the eager module.

    Parameters are shared. With ``fuse_wgrad_accumulation``, the float32
    ``main_grad`` buffers on FC1/FC2 weights are also shared. Every collected
    tensor must match with rtol=atol=0.
    """
    config = model_configs[model]

    ln_mlp = LayerNormMLP(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        eps=config.eps,
        bias=bias,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=True,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()

    ln_mlp_ref = LayerNormMLP(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        eps=config.eps,
        bias=bias,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=False,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()

    # Share params
    with torch.no_grad():
        ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
        ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
        ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
        if bias:
            ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
            ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
        if fuse_wgrad_accumulation:
            ln_mlp.fc1_weight.main_grad = torch.rand_like(ln_mlp.fc1_weight, dtype=torch.float32)
            ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone()
            ln_mlp.fc2_weight.main_grad = torch.rand_like(ln_mlp.fc2_weight, dtype=torch.float32)
            ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone()

    te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True)
    te_outputs_ref = _test_granular_accuracy(
        ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False
    )

    # Should be bit-wise match
    for o, o_ref in zip(te_outputs, te_outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


def _test_grouped_linear_accuracy(
    block,
    num_gemms,
    bs,
    dtype,
    config,
    recipe,
    fp8,
    fuse_wgrad_accumulation,
    delay_wgrad_compute=False,
):
    """Run one forward/backward pass through a grouped-linear block and collect results.

    ``block`` is either a ``GroupedLinear`` (called once with the per-GEMM split
    sizes) or an indexable container of ``num_gemms`` modules, each applied to its
    own slice of the input along dim 0 with the outputs concatenated.

    Args:
        block: ``GroupedLinear`` instance or a sequence of per-GEMM modules.
        num_gemms: number of GEMMs, i.e. number of input splits.
        bs: batch size of the generated input.
        dtype: dtype of the generated input.
        config: model config; ``max_seqlen_q`` and ``hidden_size`` are read.
        recipe: FP8 recipe forwarded to ``fp8_autocast``.
        fp8: whether to run under ``fp8_autocast`` (also affects split granularity).
        fuse_wgrad_accumulation: not read by this helper; callers prepare
            ``main_grad`` on the parameters themselves.
        delay_wgrad_compute: if True, call ``backward_dw()`` on the block (or on
            each sub-module) after ``loss.backward()``.

    Returns:
        ``[out, input_grad, *param_grads]``. Each trainable parameter contributes
        its ``main_grad`` when one is set, otherwise its ``.grad``.
    """
    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    if num_gemms > 1:
        # Under FP8, split sizes are kept multiples of 16 (128 for MXFP8).
        split_size = 1
        if fp8:
            split_size = 16
            if recipe.mxfp8():
                split_size = 128
        m = config.max_seqlen_q // split_size
        dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
        # Repeat the last cut point to force one zero-sized split. With
        # num_gemms == 2 there are no random cut points, so cut at 0 instead
        # (indexing dist[-1] on an empty list would raise IndexError).
        dist.append(dist[-1] if dist else 0)
        m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
        m_splits = m_splits * split_size
        assert m_splits.sum() == config.max_seqlen_q and len(m_splits) == num_gemms
    else:
        m_splits = torch.tensor([config.max_seqlen_q])

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, GroupedLinear):
            # Splits are scaled by bs for GroupedLinear (presumably it flattens
            # the sequence and batch dims together).
            m_splits = m_splits * bs
            out = block(inp_hidden_states, m_splits.tolist())
        else:
            out = torch.cat(
                [
                    block[i](inp)
                    for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                ]
            )
    loss = out.sum()
    loss.backward()
    if delay_wgrad_compute:
        if isinstance(block, GroupedLinear):
            block.backward_dw()
        else:
            for i in range(num_gemms):
                block[i].backward_dw()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            if getattr(p, "main_grad", None) is not None:
                outputs.append(p.main_grad)
                assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
            else:
                outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_grouped_linear_accuracy(
    dtype,
    num_gemms,
    bs,
    model,
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
    bias,
    delay_wgrad_compute,
    parallel_mode=None,
):
    """Check that GroupedLinear matches a list of Linear modules bit-for-bit.

    Both models share weights (and biases / ``main_grad`` buffers when enabled).
    Outputs, input gradients and parameter gradients must be exactly equal.
    A ``None`` recipe runs the test without FP8.
    """
    fp8 = recipe is not None
    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=bias,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            delay_wgrad_compute=delay_wgrad_compute,
            save_original_input=False,
        ).eval()
        sequential_linear = torch.nn.ModuleList(
            [
                Linear(
                    config.hidden_size,
                    4 * config.hidden_size,
                    bias=bias,
                    params_dtype=dtype,
                    parallel_mode=parallel_mode,
                    device="cuda",
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                ).eval()
                for _ in range(num_gemms)
            ]
        )

    # Share params
    with torch.no_grad():
        for i in range(num_gemms):
            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
            if bias:
                sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
            if fuse_wgrad_accumulation:
                weight_i = getattr(grouped_linear, f"weight{i}")
                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
                sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()

    outputs_ref = _test_grouped_linear_accuracy(
        sequential_linear,
        num_gemms,
        bs,
        dtype,
        config,
        recipe,
        fp8,
        fuse_wgrad_accumulation,
        delay_wgrad_compute,
    )
    outputs = _test_grouped_linear_accuracy(
        grouped_linear,
        num_gemms,
        bs,
        dtype,
        config,
        recipe,
        fp8,
        fuse_wgrad_accumulation,
        delay_wgrad_compute,
    )

    # Should be a bit-wise match
    for o, o_ref in zip(outputs, outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3])
@pytest.mark.parametrize("bs", [1])
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("fp8_model_params", [False])
@pytest.mark.parametrize("fuse_wgrad_accumulation", [True])
@pytest.mark.parametrize("bias", [False])
@pytest.mark.parametrize("delay_wgrad_compute", [True])
def test_grouped_linear_accuracy_save_original_input(
    dtype,
    num_gemms,
    bs,
    model,
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
    bias,
    delay_wgrad_compute,
    parallel_mode=None,
):
    """Check GroupedLinear(save_original_input=True) against a list of Linear modules.

    Results must match bit-for-bit. On ROCm, both runs are forced onto the same
    GEMM path through ``NVTE_FORCE_ROCM_GEMM``. The caller's previous value is
    restored afterwards, even if a run fails.
    """
    fp8 = recipe is not None
    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=bias,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            delay_wgrad_compute=delay_wgrad_compute,
            save_original_input=True,
        ).eval()
        sequential_linear = torch.nn.ModuleList(
            [
                Linear(
                    config.hidden_size,
                    4 * config.hidden_size,
                    bias=bias,
                    params_dtype=dtype,
                    parallel_mode=parallel_mode,
                    device="cuda",
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                ).eval()
                for _ in range(num_gemms)
            ]
        )

    # Share params
    with torch.no_grad():
        for i in range(num_gemms):
            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
            if bias:
                sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
            if fuse_wgrad_accumulation:
                weight_i = getattr(grouped_linear, f"weight{i}")
                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
                sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()

    # Force the sequential_linear and grouped_linear to use hipblaslt rather than hipblas.
    # Remember the previous setting so it can be restored rather than clobbered.
    if IS_HIP_EXTENSION:
        ori_force_rocm_gemm = os.environ.get("NVTE_FORCE_ROCM_GEMM", None)
        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
    try:
        outputs_ref = _test_grouped_linear_accuracy(
            sequential_linear,
            num_gemms,
            bs,
            dtype,
            config,
            recipe,
            fp8,
            fuse_wgrad_accumulation,
            delay_wgrad_compute,
        )
        outputs = _test_grouped_linear_accuracy(
            grouped_linear,
            num_gemms,
            bs,
            dtype,
            config,
            recipe,
            fp8,
            fuse_wgrad_accumulation,
            delay_wgrad_compute,
        )
    finally:
        # Restore the environment even when a run raises, so later tests are unaffected.
        if IS_HIP_EXTENSION:
            if ori_force_rocm_gemm is None:
                os.environ.pop("NVTE_FORCE_ROCM_GEMM", None)
            else:
                os.environ["NVTE_FORCE_ROCM_GEMM"] = ori_force_rocm_gemm

    # Should be a bit-wise match
    for o, o_ref in zip(outputs, outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("recipe", fp8_recipes + [None])
def test_grouped_linear_accuracy_single_gemm(recipe):
    """Cover the num_gemms == 1 case with one fixed configuration.

    This runs as a separate test instead of adding ``1`` to the full
    ``num_gemms`` parametrization, to save CI time.
    """
    test_grouped_linear_accuracy(
        dtype=torch.float32,
        num_gemms=1,
        bs=2,
        model="126m",
        recipe=recipe,
        fp8_model_params=True,
        fuse_wgrad_accumulation=True,
        bias=True,
        delay_wgrad_compute=False,
    )


def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
    """Run forward/backward on a padded-vs-unpadded grouped linear and collect results.

    Input rows are split into ``num_gemms`` random, seeded chunk sizes.
    - A ``TorchGroupedLinearWithPadding`` block gets the raw splits directly.
    - Any other block, under FP8, gets a zero-padded input whose per-split row
      counts are rounded up to 16 (32 for MXFP8); the padding is stripped from
      its output afterwards.

    Returns:
        ``[out, input_grad, *param_grads]`` for the trainable parameters of ``block``.
    """

    def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
        # Round each split up to the alignment and append zero rows to fill it.
        align_size = 16
        if recipe.mxfp8():
            align_size = 32
        padded_tokens_per_expert = [
            (num_tokens + align_size - 1) // align_size * align_size
            for num_tokens in tokens_per_expert
        ]
        hidden_states = torch.split(hidden_states, tokens_per_expert)
        padded_hidden_states = []
        for hidden_state, actual_num_tokens, padded_num_tokens in zip(
            hidden_states, tokens_per_expert, padded_tokens_per_expert
        ):
            padded_hidden_states.append(hidden_state)
            if padded_num_tokens > actual_num_tokens:
                pad_tensor = torch.zeros(
                    padded_num_tokens - actual_num_tokens,
                    hidden_state.shape[1],
                    dtype=hidden_state.dtype,
                    device=hidden_state.device,
                )
                padded_hidden_states.append(pad_tensor)
        padded_hidden_states = torch.cat(padded_hidden_states, dim=0)
        return padded_hidden_states, padded_tokens_per_expert

    def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert):
        # Split by padded sizes, keep only the leading real rows of each chunk.
        padded_mats = torch.split(
            padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert
        )
        hidden_states = torch.cat(
            [
                padded_mat[:num_actual]
                for padded_mat, num_actual in zip(padded_mats, actual_tokens_per_expert)
            ],
            dim=0,
        )

        return hidden_states

    def _generate_random_numbers(n, total_sum):
        # Return n positive ints summing to total_sum, reproducibly (reseeds `random`).
        if n <= 0:
            return []

        # reset seed
        random.seed(seed)

        breaks = sorted(random.sample(range(1, total_sum), n - 1))
        # Differences of consecutive boundaries; this also covers n == 1 (no
        # breaks), where indexing breaks[0] would raise IndexError.
        bounds = [0] + breaks + [total_sum]
        return [upper - lower for lower, upper in zip(bounds, bounds[1:])]

    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.max_seqlen_q * bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, TorchGroupedLinearWithPadding):
            out = block(inp_hidden_states, m_splits)
        else:
            if fp8:
                padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
                    inp_hidden_states, m_splits
                )
                padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits)
                out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits)
            else:
                out = block(inp_hidden_states, m_splits)

    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy(
    dtype,
    num_gemms,
    bs,
    model,
    fp8,
    recipe,
    fp8_model_params,
    parallel_mode=None,
):
    """Check TorchGroupedLinearWithPadding against GroupedLinear on manually padded input.

    Weights are shared between the two models; results must match bit-for-bit.
    """
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = TorchGroupedLinearWithPadding(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            fp8=fp8,
        ).eval()

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        ref_grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            save_original_input=False,
        ).eval()

    # Share params
    with torch.no_grad():
        inner_grouped_linear = grouped_linear.linear_fn
        for i in range(num_gemms):
            setattr(
                ref_grouped_linear,
                f"weight{i}",
                Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
            )

    outputs = _test_padding_grouped_linear_accuracy(
        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
    )
    outputs_ref = _test_padding_grouped_linear_accuracy(
        ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
    )

    # Should be a bit-wise match
    for o, o_ref in zip(outputs, outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3])
@pytest.mark.parametrize("bs", [1])
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", [False])
def test_padding_grouped_linear_accuracy_save_original_input(
    dtype,
    num_gemms,
    bs,
    model,
    fp8,
    recipe,
    fp8_model_params,
    parallel_mode=None,
):
    """Same as test_padding_grouped_linear_accuracy, with GroupedLinear(save_original_input=True).

    DelayedScaling recipes are skipped. Results must match bit-for-bit.
    """
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = TorchGroupedLinearWithPadding(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            fp8=fp8,
        ).eval()

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        ref_grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            save_original_input=True,
        ).eval()

    # Share params
    with torch.no_grad():
        inner_grouped_linear = grouped_linear.linear_fn
        for i in range(num_gemms):
            setattr(
                ref_grouped_linear,
                f"weight{i}",
                Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
            )

    outputs = _test_padding_grouped_linear_accuracy(
        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
    )
    outputs_ref = _test_padding_grouped_linear_accuracy(
        ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
    )

    # Should be a bit-wise match
    for o, o_ref in zip(outputs, outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
    """Run one SGD training step on ``block``, optionally captured in a CUDA graph.

    Three warmup steps run on a side stream. If ``graph`` is True, one step is
    then captured and replayed on new input/target data. Otherwise the step runs
    eagerly on the same new data.

    Returns:
        ``(output, grads)``: a clone of the final step's output, and
        ``[input_grad, *param_grads]`` for trainable parameters.
    """
    reset_rng_states()

    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for graph capture.
    static_input = torch.randn(
        config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
    )
    static_target = torch.randn(
        config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype
    )

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    # Basic training loop.
    def train_step():
        optimizer.zero_grad(set_to_none=False)
        out = block(static_input)
        loss = loss_fn(out, static_target)
        loss.backward()
        optimizer.step()
        return out

    # Warmup steps in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            train_step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture graph.
    g = None
    static_output = None
    if graph:
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_output = train_step()

    # Run with new data: copy into the captured placeholders in place so the graph sees it.
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    if graph:
        g.replay()
    else:
        static_output = train_step()

    grads = [static_input.grad]
    for p in block.parameters():
        if p.requires_grad:
            grads.append(p.grad)

    with torch.no_grad():
        output = static_output.clone()
    return output, grads


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_cuda_graph(dtype, bs, model):
    """Check that a CUDA-graphed TransformerLayer training step matches eager execution.

    Two identically-initialized layers are trained, one eagerly and one through
    graph capture/replay. Outputs, updated parameters and gradients are compared
    with a 1e-3 tolerance.
    """
    if NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("Cuda Graphs are not supported in debug mode.")
    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block_args = (
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
    )
    block_kwargs = dict(
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        device="cuda",
    )
    block = TransformerLayer(*block_args, **block_kwargs)
    graphed_block = TransformerLayer(*block_args, **block_kwargs)
    with torch.no_grad():
        for param1, param2 in zip(block.parameters(), graphed_block.parameters()):
            param2.copy_(param1)

    out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
    graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
    params = list(block.parameters())
    graphed_params = list(graphed_block.parameters())

    # Check that results match
    assert_allclose(out, graphed_out, 1e-3)
    assert_allclose(params, graphed_params, 1e-3)
    assert_allclose(grads, graphed_grads, 1e-3)


def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
    """Build a TransformerLayer and run one FP8 forward/backward pass on a causal mask.

    Args:
        bs: batch size.
        dtype: parameter and input dtype.
        config: model config (hidden size, heads, eps, kv_channels, seq length...).
        fp8_model_params: whether the layer is constructed under ``fp8_model_init``.
        recipe: FP8 recipe used for both model init and autocast.

    Returns:
        ``[output, input_grad, *param_grads]`` for trainable parameters.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    with fp8_model_init(enabled=fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

    with fp8_autocast(enabled=True, fp8_recipe=recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    outputs = [te_out, te_inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
    """Check that FP8 model parameters give results close to high-precision parameters.

    The same layer is run with and without ``fp8_model_init``. Outputs and
    gradients are compared with loose tolerances.
    """
    if NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")

    config = model_configs[model]

    outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe)
    outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True, recipe)

    # Check that results match (tolerances defined once and reused).
    tols = dict(rtol=0.125, atol=0.0675)
    for i, (ref, test) in enumerate(zip(outputs, outputs_fp8_params)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_transformer_layer_hidden_states_format(dtype, bs, model):
    """Check that TransformerLayer gives the same output for sbhd, bshd and thd inputs.

    All three layers are built from the same seed with dropout disabled, so their
    weights are identical. The thd comparison is skipped for float32 and for
    pre-Ampere GPUs.
    """
    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    def _build_layer(attn_input_format, **extra_kwargs):
        # Set `torch.manual_seed` to make sure the weights are identical to the
        # other layers. `*dropout` values are 0 so the forward pass is identical too.
        torch.manual_seed(0)
        return TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0,
            attention_dropout=0,
            kv_channels=config.kv_channels,
            params_dtype=dtype,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            device="cuda",
            attn_input_format=attn_input_format,
            **extra_kwargs,
        )

    block_sbhd = _build_layer("sbhd")
    block_bshd = _build_layer("bshd")
    block_thd = _build_layer("thd", self_attn_mask_type="padding_causal")

    for (n1, p1), (n2, p2), (n3, p3) in zip(
        block_bshd.named_parameters(), block_sbhd.named_parameters(), block_thd.named_parameters()
    ):
        assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical"

    x_sbhd = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    x_bshd = x_sbhd.transpose(0, 1).contiguous()
    x_thd = x_bshd.reshape(bs * config.max_seqlen_q, config.hidden_size).contiguous()
    # Every sequence has full length max_seqlen_q, so offsets are multiples of it.
    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.max_seqlen_q

    # To make sure forward is also identical (just in case some module decides
    # to act fancy)
    torch.manual_seed(0)
    y_sbhd = block_sbhd(x_sbhd)

    # To make sure forward is also identical (just in case some module decides
    # to act fancy)
    torch.manual_seed(0)
    y_bshd = block_bshd(x_bshd)

    # Check that results match
    torch.testing.assert_close(
        y_bshd,
        y_sbhd.transpose(0, 1).contiguous(),
    )

    # THD is not supported in float32 and on GPUs older than Ampere, skip the test here
    if dtype != torch.float32 and sm_80plus:
        # To make sure forward is also identical (just in case some module decides
        # to act fancy)
        torch.manual_seed(0)
        y_thd = block_thd(
            x_thd,
            cu_seqlens_q=x_thd_cumsum,
            cu_seqlens_kv=x_thd_cumsum,
            max_seqlen_q=config.max_seqlen_q,
            max_seqlen_kv=config.max_seqlen_kv,
        )

        torch.testing.assert_close(
            y_bshd,
            y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(),
        )

@pytest.mark.parametrize(
    "shape",
    [
        (1, 127, 128, 512),
        (8, 15, 128, 512),
        (8, 1027, 128, 512),
        (16, 10027, 128, 512),
    ],
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm(shape, dtype, layout, accumulate):
    """Check that ``general_grouped_gemm`` bit-wise matches a loop of ``general_gemm`` calls.

    ``shape`` is ``(z, m, k, n)``: ``z`` GEMMs whose ``m`` rows are split at random
    cut points into ``m_splits`` (cut points may repeat, so some groups can be empty).

    ``layout`` selects which pass is emulated:

    * ``"TN"`` -- per-group weight x split input, written into one fused output.
    * ``"NN"`` -- per-group weight x split grad_output (dgrad), one fused output.
    * ``"NT"`` -- split input x split grad_output (wgrad), one output per group.

    ``accumulate`` is forwarded to both the reference and the grouped call. Both
    start from identical (cloned) random output contents, so accumulation is
    compared as well.
    """
    torch.manual_seed(0)
    z, m, k, n = shape

    # z - 1 sorted random cut points in [0, m) define z contiguous row groups.
    dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist()
    m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
    assert m_splits.sum() == m and len(m_splits) == z
    m_splits = m_splits.tolist()

    if layout == "TN":
        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
        B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
        out = [torch.randn(m, n, dtype=dtype, device="cuda")]  # output
        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
        grad = False
        single_output = True
    elif layout == "NN":
        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
        B = list(
            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
        )  # grad_output
        out = [torch.randn(m, k, dtype=dtype, device="cuda")]  # dgrad
        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
        grad = True
        single_output = True
    else:  # layout == "NT"
        A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
        B = list(
            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
        )  # grad_output
        out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # wgrad
        out_ref = [o.clone() for o in out]
        grad = True
        single_output = False

    # On ROCm, force both the reference general_gemm loop and general_grouped_gemm
    # to use hipblaslt rather than hipblas, so the bit-wise comparison below
    # compares like with like. The previous value is restored in ``finally`` so
    # a failing GEMM cannot leak the override into later tests.
    if IS_HIP_EXTENSION:
        ori_force_rocm_gemm = os.environ.get("NVTE_FORCE_ROCM_GEMM", None)
        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
    try:
        # Reference: one independent GEMM per group.
        for i in range(z):
            general_gemm(
                A[i],
                B[i],
                get_workspace(),
                dtype,
                grad=grad,
                accumulate=accumulate,
                layout=layout,
                out=out_ref[i],
            )
        if single_output:
            # Stitch per-group references back into the single fused layout.
            out_ref = [torch.cat(out_ref)]

        general_grouped_gemm(
            A,
            B,
            out,
            dtype,
            get_multi_stream_cublas_workspace(),
            m_splits=m_splits,
            grad=grad,
            accumulate=accumulate,
            layout=layout,
            single_output=single_output,
        )
    finally:
        if IS_HIP_EXTENSION:
            if ori_force_rocm_gemm is not None:
                os.environ["NVTE_FORCE_ROCM_GEMM"] = ori_force_rocm_gemm
            else:
                os.environ.pop("NVTE_FORCE_ROCM_GEMM", None)

    # should be bit-wise match
    for o, o_ref in zip(out, out_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize(
    "shape",
    [
        (1, 128, 128, 512),
        (8, 1024, 128, 512),
        (16, 4096, 128, 512),
    ],
)
@pytest.mark.parametrize("accumulate", [False, True])
def test_fp8_grouped_gemm(shape, accumulate):
    """Check that FP8 ``general_grouped_gemm`` bit-wise matches per-group ``general_gemm``.

    ``shape`` is ``(z, m, k, n)``. The ``m`` input rows are split evenly into ``z``
    groups. Per-group weights and inputs are quantized to FP8 E4M3, with one
    ``Float8Quantizer`` per operand per group. The results are written into bf16
    output buffers that start from identical random contents, so ``accumulate``
    is compared as well.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)

    z, m, k, n = shape
    # Equal-sized groups; fail loudly instead of letting torch.split reject
    # sizes that do not sum to m.
    assert m % z == 0, f"m ({m}) must be divisible by z ({z}) for equal splits"
    m_splits = [m // z] * z

    dtype = torch.bfloat16
    A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
    B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # input
    # torch.split returns views, so the grouped GEMM writes into one (m, n) buffer.
    out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # output
    out_ref = [o.clone() for o in out]

    # fp8 should be robust enough to this fake scale
    scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze()
    amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda")

    def _make_quantizers():
        # One quantizer per group, each owning its own scale/amax copies.
        return [
            Float8Quantizer(
                scale.clone(),
                amax.clone(),
                tex.DType.kFloat8E4M3,
            )
            for _ in range(z)
        ]

    a_quantizers = _make_quantizers()
    b_quantizers = _make_quantizers()

    A_fp8 = [quantizer(a) for quantizer, a in zip(a_quantizers, A)]
    B_fp8 = [quantizer(b) for quantizer, b in zip(b_quantizers, B)]

    # baseline
    for i in range(z):
        general_gemm(
            A_fp8[i],
            B_fp8[i],
            get_workspace(),
            dtype,
            out=out_ref[i],
            accumulate=accumulate,
        )

    general_grouped_gemm(
        A_fp8,
        B_fp8,
        out,
        dtype,
        get_multi_stream_cublas_workspace(),
        m_splits=m_splits,
        accumulate=accumulate,
    )

    # should be bit-wise match
    for o, o_ref in zip(out, out_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


def test_noncontiguous():
    def _create2modules(m, params):
        mod1 = m(*params)
        mod2 = m(*params)
        for p1, p2 in zip(mod1.parameters(), mod2.parameters()):
            p2.data = p1.data.clone()

        return mod1, mod2

    def _run_module(m, inp):
        out = m(inp)
        out.sum().backward()
        ret = [out]
        if inp.grad is not None:
            ret.append(inp.grad)

        for p in m.parameters():
            if p.requires_grad:
                ret.append(p.grad)
        return ret

    a = torch.randn((128, 256), device="cuda", requires_grad=True)
    a = a.T
    assert not a.is_contiguous(), "The test is supposed to test noncontiguous input."

    b = a.contiguous()

    # LayerNorm
    ln1, ln2 = _create2modules(LayerNorm, [128])
    outT = _run_module(ln1, a)
    out = _run_module(ln2, b)

    assert_allclose(out, outT, 1e-7)

    # RMSNorm
    ln1, ln2 = _create2modules(RMSNorm, [128])
    outT = _run_module(ln1, a)
    out = _run_module(ln2, b)

    assert_allclose(out, outT, 1e-7)

    # GEMM
    g1, g2 = _create2modules(Linear, [128, 128])
    outT = _run_module(g1, a)
    out = _run_module(g2, b)

    assert_allclose(out, outT, 1e-7)