test_numerics.py 76.6 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
from collections import OrderedDict
6
import math
7
import os
8
from typing import Dict, List, Tuple, Optional
9
import pytest
10
import copy
11
import random
12
13
14
15

import torch
import torch.nn as nn
from torch.nn import Parameter
yuguo's avatar
yuguo committed
16
from torch.utils.cpp_extension import IS_HIP_EXTENSION
17

18
19
20
21
22
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    fp8_autocast,
    fp8_model_init,
)
23
24
25
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
26
    attention_mask_func,
27
    is_bf16_compatible,
28
29
)
from transformer_engine.pytorch import (
30
31
32
33
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
34
    GroupedLinear,
35
36
37
38
    MultiheadAttention,
    RMSNorm,
    TransformerLayer,
    LayerNorm,
39
40
    Fp8Padding,
    Fp8Unpadding,
41
)
42
from transformer_engine.pytorch import torch_version
43
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
44
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
45
46
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
47
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
48
from transformer_engine.pytorch.utils import get_device_compute_capability
49
from transformer_engine.common import recipe
50
import transformer_engine_torch as tex
51

52
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# True when the device reports compute capability (8, 0) or newer.
sm_80plus = get_device_compute_capability() >= (8, 0)

seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

# The dynamo knob is set under a different attribute name from torch 2.7.0 onwards.
if torch_version() >= (2, 7, 0):
    torch._dynamo.config.recompile_limit = 16
else:
    torch._dynamo.config.cache_size_limit = 16
69
70
71
72
73
74
75
76
77
78
79
80

class ModelConfig:
    """Plain container for the model dimensions used by the tests in this file."""

    def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
        # Stored verbatim; no validation or derived values are computed here.
        self.hidden_size, self.eps = hidden_size, eps
        self.num_attention_heads, self.embed = num_attention_heads, embed
        self.num_layers, self.seq_len = num_layers, seq_len


# ModelConfig arguments: hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
model_configs = {
    "small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}

model_configs_inference = {
    # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]

param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

batch_sizes = [1, 2]

all_boolean = [True, False]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]

all_normalizations = ["LayerNorm", "RMSNorm"]

mask_types = ["causal", "no_mask"]

# Recipe objects are instantiated once at import time and shared by every parametrized test.
fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),
    recipe.Float8CurrentScaling(),
]

113

114
115
116
117
def get_causal_attn_mask(sq: int) -> torch.Tensor:
    """Build an ``sq x sq`` boolean mask on the CUDA device.

    Entries strictly above the main diagonal are ``True``; the diagonal and
    everything below it are ``False``.
    """
    all_ones = torch.ones(sq, sq, device="cuda")
    strictly_upper = torch.triu(all_ones, diagonal=1)
    return strictly_upper.bool()


118
119
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.

    Returns a new dict with ``rtol`` and ``atol`` entries on every call, so
    callers may modify the result freely.

    Raises:
        ValueError: if ``dtype`` is not float32, float16 or bfloat16.
    """
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    raise ValueError(f"Unsupported dtype ({dtype})")


def assert_allclose(
    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: Optional[float] = None
) -> None:
    """Ensures two lists of tensors are element-wise close.

    Tensors are compared pairwise with ``torch.allclose``. ``rtol`` is only
    forwarded when given, so ``torch.allclose``'s own default applies otherwise.

    Raises:
        AssertionError: if the lists differ in length, or if any pair is not
            close. The message names the list index and, when an offending
            element can be located, its position and both values.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."

    tols = dict(atol=atol)
    if rtol is not None:
        tols["rtol"] = rtol
    # Mirror torch.allclose's default rtol (1e-5) so the diagnostic below applies
    # the same threshold as the check itself when the caller passed no rtol.
    effective_rtol = 1e-5 if rtol is None else rtol

    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if torch.allclose(t1, t2, **tols):
            continue

        # Fallback message: used when no element exceeds the tolerance in the
        # comparison below even though allclose failed (e.g. NaN entries).
        msg = f"Outputs not close enough in tensor at idx={i}."

        diff = torch.abs(t1 - t2)
        tol = atol + (effective_rtol * torch.abs(t2))
        exceed_mask = diff > tol
        if exceed_mask.any():
            indices = torch.nonzero(exceed_mask, as_tuple=True)
            max_diff = diff[exceed_mask].max()
            # Position of the first worst offender within the masked (flattened) view.
            max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
            max_location = [idx[max_idx].item() for idx in indices]
            msg = (
                f"Outputs not close enough in tensor at idx={i}. "
                f"Maximum difference at location {max_location} "
                f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
                f"(diff {max_diff.item()})."
            )
        raise AssertionError(msg)
159
160
161


def reset_rng_states() -> None:
    """Revert back to initial RNG state.

    Restores the CPU and CUDA generators from the module-level snapshots
    ``_cpu_rng_state`` and ``_cuda_rng_state``.
    """
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that resets the global FP8 state after each test."""
    # No setup is needed; the code after the yield runs as fixture teardown.
    yield
    FP8GlobalStateManager.reset()
171
172


173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
class TorchScaledMaskedSoftmax(nn.Module):
    """Reference softmax over the last dimension with optional scaling and masking.

    The computation is carried out in float32 and the result is cast back to
    the dtype of the input.
    """

    def forward(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        original_dtype = inp.dtype
        scores = inp.float()

        if scale is not None:
            scores = scores * scale
        # Masking is delegated to attention_mask_func and skipped when no mask is given.
        if mask is not None:
            scores = attention_mask_func(scores, mask)

        return torch.nn.Softmax(dim=-1)(scores).to(original_dtype)


class TorchDotProductAttention(torch.nn.Module):
    """Reference scaled dot-product attention built from plain PyTorch ops.

    Raw scores are scaled by ``1 / sqrt(kv_channels)``, passed through
    ``TorchScaledMaskedSoftmax`` and dropout, then applied to the values.
    """

    def __init__(
        self,
        kv_channels: int,
        attention_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.norm_factor = math.sqrt(kv_channels)
        self.scale_mask_softmax = TorchScaledMaskedSoftmax()
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention; shape comments below use the [s, b, np, hn] layout."""
        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        # Allocated on the query's device so the module does not depend on
        # which CUDA device happens to be current.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 discards the (uninitialised) contents of matmul_result.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
        attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        # query_layer was reshaped above, but its size(0) is still sq.
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        context_layer = context_layer.view(seqlen, batch_size, -1)

        return context_layer

280

281
class TorchLayerNorm(nn.Module):
    """Reference LayerNorm with optional zero-centered gamma.

    With ``zero_centered_gamma`` the effective scale is ``1 + weight``;
    otherwise it is ``weight`` itself. The computation runs in float32 and the
    result is cast back to the input dtype.
    """

    def __init__(self, in_features: int, eps: float, zero_centered_gamma: bool):
        super().__init__()
        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        # Start from an identity scale: forward() adds 1 to the weight in the
        # zero-centered case, so the stored weight must be 0 there and 1 otherwise.
        initial_value = torch.zeros(in_features) if zero_centered_gamma else torch.ones(in_features)
        # Assigning nn.Parameter attributes registers them with the module.
        self.weight = nn.Parameter(initial_value)
        self.bias = nn.Parameter(torch.zeros(in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight if not self.zero_centered_gamma else 1 + self.weight
        w = w.to(torch.float32)
        b = self.bias.to(torch.float32)
        inp = x.to(torch.float32)
        out = torch.nn.functional.layer_norm(
            inp, (self.in_features,), weight=w, bias=b, eps=self.eps
        )
        return out.to(x.dtype)

304

305
306
# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
    """Reference RMSNorm with optional zero-centered gamma.

    The input is scaled by ``(mean(x^2) + eps) ** -0.5`` over the last
    dimension and multiplied by the weight (``1 + weight`` when
    ``zero_centered_gamma`` is set). The result is cast back to the input dtype.
    """

    def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
        super().__init__()

        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        # Start from an identity scale: forward() adds 1 to the weight in the
        # zero-centered case, so the stored weight must be 0 there and 1 otherwise.
        initial_value = torch.zeros(in_features) if zero_centered_gamma else torch.ones(in_features)
        # Assigning an nn.Parameter attribute registers it with the module.
        self.weight = nn.Parameter(initial_value)

    def forward(self, x):
        # Sum of squares is accumulated in float32.
        norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
        d_x = self.in_features

        rms_x2 = norm_x2 / d_x + self.eps
        r_rms_x = rms_x2 ** (-1.0 / 2)
        x_normed = x * r_rms_x

        w = self.weight.float()
        if self.zero_centered_gamma:
            w = 1 + w
        return (w * x_normed).to(x.dtype)
330

331

332
class TorchLayerNormLinear(nn.Module):
    """Reference normalization followed by ``nn.Linear``.

    ``normalization`` selects ``TorchLayerNorm`` ("LayerNorm") or
    ``TorchRMSNorm`` ("RMSNorm").

    Raises:
        RuntimeError: if ``normalization`` is neither of the two names.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float,
        normalization: str = "LayerNorm",
        zero_centered_gamma: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        # The attribute is named "layernorm" for both normalization kinds.
        if normalization == "LayerNorm":
            self.layernorm = TorchLayerNorm(
                in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
            )
        elif normalization == "RMSNorm":
            self.layernorm = TorchRMSNorm(
                in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
            )
        else:
            raise RuntimeError("Unsupported normalization")

        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.layernorm(x))


class TorchMHA(nn.Module):
    """Self-attention wrapper around ``nn.MultiheadAttention`` (dropout 0.1, bias on)."""

    def __init__(self, hidden_size: int, num_attention_heads: int):
        super().__init__()
        self.mhsa = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            dropout=0.1,
            bias=True,
            batch_first=False,
        )

    def forward(self, x, attention_mask=None):
        """Attend ``x`` to itself and return only the attention output tensor."""
        # Query, key and value are all the same tensor.
        output = self.mhsa(x, x, x, attn_mask=attention_mask, need_weights=False)
        # Drop the attention-weights slot when the call returns a tuple.
        if isinstance(output, tuple):
            output = output[0]
        return output

377

378
379
380
class TorchQuickGELU(nn.Module):
    """Reference activation computing ``x * sigmoid(1.702 * x)`` element-wise."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(1.702 * input)
        return input * gate
381

382

383
384
385
386
class TorchSquaredRELU(nn.Module):
    """Reference activation: ``x * x`` where ``x > 0``, zero elsewhere."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # The boolean mask acts as a 0/1 multiplier.
        positive = input > 0
        return positive * input * input

387

388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
class TorchGroupedLinearWithPadding(nn.Module):
    """``GroupedLinear`` wrapped with ``Fp8Padding`` / ``Fp8Unpadding``.

    Padding and unpadding are applied only when the ``fp8`` flag passed at
    construction is truthy; otherwise the input goes straight to the grouped linear.
    """

    def __init__(
        self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8
    ) -> None:
        super().__init__()

        self.padding = Fp8Padding(num_gemms)
        grouped_linear_kwargs = dict(
            bias=bias,
            params_dtype=params_dtype,
            parallel_mode=parallel_mode,
            device="cuda",
        )
        self.linear_fn = GroupedLinear(
            num_gemms, in_features, out_features, **grouped_linear_kwargs
        )
        self.unpadding = Fp8Unpadding(num_gemms)

        self.fp8 = fp8

    def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor:
        if not self.fp8:
            return self.linear_fn(inp, m_splits)

        # Pad, run the grouped GEMMs on the padded splits, then unpad using
        # the caller's original splits.
        padded_inp, padded_splits = self.padding(inp, m_splits)
        padded_out = self.linear_fn(padded_inp, padded_splits)
        return self.unpadding(padded_out, m_splits)


422
423
424
425
426
427
428
429
430
# Activation name -> reference module. For the "*glu" names this is the
# activation that TorchGLU applies to the first half of its input.
_supported_act = dict(
    geglu=nn.GELU(approximate="tanh"),
    gelu=nn.GELU(approximate="tanh"),
    reglu=nn.ReLU(),
    relu=nn.ReLU(),
    swiglu=nn.SiLU(),
    qgelu=TorchQuickGELU(),
    srelu=TorchSquaredRELU(),
)
431

432

433
434
435
436
437
438
439
class TorchGLU(nn.Module):
    """Reference gated linear unit.

    The last dimension of the input is split in two; the activation looked up
    in ``_supported_act`` is applied to the first part, and the result is
    multiplied by the second part.
    """

    def __init__(self, activation: str):
        super().__init__()
        self.act = _supported_act[activation]

    def forward(self, x):
        shape = x.size(-1)
        a = x[..., : shape // 2]
        b = x[..., (shape // 2) :]
        a = self.act(a)
        return a * b
444

445

446
class TorchLayerNormMLP(nn.Module):
    """Reference normalization -> fc1 -> activation -> fc2 block.

    For activations whose name contains "glu", fc1 produces
    ``2 * ffn_hidden_size`` features and the activation is a ``TorchGLU``.

    Raises:
        RuntimeError: if ``normalization`` is neither "LayerNorm" nor "RMSNorm".
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        eps: float = 1e-5,
        activation="gelu",
        normalization: str = "LayerNorm",
        bias: bool = True,
    ):
        super().__init__()
        if normalization == "LayerNorm":
            self.ln = TorchLayerNorm(hidden_size, eps=eps, zero_centered_gamma=False)
        elif normalization == "RMSNorm":
            self.ln = TorchRMSNorm(hidden_size, eps=eps, zero_centered_gamma=False)
        else:
            raise RuntimeError("Unsupported normalization")
        # The attribute is called "gelu" whichever activation is selected.
        if "glu" in activation:
            fc1_output_features = 2 * ffn_hidden_size
            self.gelu = TorchGLU(activation)
        else:
            fc1_output_features = ffn_hidden_size
            self.gelu = _supported_act[activation]

        self.fc1 = nn.Linear(hidden_size, fc1_output_features, bias=bias)
        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)

    def forward(self, x):
        t = self.gelu(self.fc1(self.ln(x)))
        return self.fc2(t)
476
477
478


class TorchGPT(nn.Module):
    """Reference transformer block: LayerNorm, ``TorchMHA`` and ``TorchLayerNormMLP``.

    With ``parallel_attention_mlp`` the MLP reads the block input and its
    output is added to the attention output before a single dropout and
    residual add. Otherwise attention and MLP are applied one after the other,
    each followed by its own dropout and residual add.
    """

    def __init__(
        self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool
    ):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)
        self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
        self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
        self.parallel_attention_mlp = parallel_attention_mlp

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        a = self.ln(x)
        b = self.causal_attn(a, attention_mask)
        if self.parallel_attention_mlp:
            n = self.ln_mlp(x)
            x = x + nn.functional.dropout(b + n, p=0.1, training=self.training)
        else:
            x = x + nn.functional.dropout(b, p=0.1, training=self.training)
            n = self.ln_mlp(x)
            x = x + nn.functional.dropout(n, p=0.1, training=self.training)
        return x


505
506
507
def _test_e2e_selective_recompute(
    bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False
):
    """Run one forward/backward pass of a TransformerLayer and collect the results.

    ``recompute`` is forwarded as ``checkpoint_core_attention``. RNG and global
    FP8 state are reset first so that repeated calls start from the same state.

    Returns:
        A list holding the layer output, the input gradient, and the gradient
        of every parameter that requires grad.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    te_inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            checkpoint_core_attention=recompute,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    outputs = [te_out, te_inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
    """Results with core-attention recompute must match results without it."""
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    config = model_configs[model]

    outputs = _test_e2e_selective_recompute(
        bs, dtype, config, fp8, recipe, fp8_model_params, recompute=False
    )
    outputs_recompute = _test_e2e_selective_recompute(
        bs, dtype, config, fp8, recipe, fp8_model_params, recompute=True
    )

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols["atol"] = 1e-4
    # FP8 runs are compared with much looser tolerances.
    if fp8 or fp8_model_params:
        tols.update(dict(rtol=0.125, atol=0.0675))

    for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )
594
595


596
def _test_e2e_full_recompute(
    bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False, use_reentrant=True
):
    """Run one forward/backward pass, optionally through ``te_checkpoint``.

    The input requires grad only when ``use_reentrant`` is set, so the input
    gradient is part of the results only in that case.

    Returns:
        ``(outputs, names)``: the output tensor and gradients, plus a parallel
        list of labels ("output", "input", then parameter names).
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    te_inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=use_reentrant,
    )
    if use_reentrant:
        te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if recompute:
            te_out = te_checkpoint(
                block,
                te_inp_hidden_states,
                attention_mask=te_inp_attn_mask,
                checkpoint_core_attention=False,
                distribute_saved_activations=False,
                tp_group=None,
                use_reentrant=use_reentrant,
            )
        else:
            te_out = block(
                te_inp_hidden_states,
                attention_mask=te_inp_attn_mask,
                checkpoint_core_attention=False,
            )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    outputs = [te_out]
    names = ["output"]
    if use_reentrant:
        outputs.append(te_inp_hidden_states.grad)
        names.append("input")
    for name, p in block.named_parameters():
        if p.requires_grad:
            outputs.append(p.grad)
            names.append(name)

    return outputs, names
666
667
668
669


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_reentrant", all_boolean)
def test_gpt_full_activation_recompute(
    dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
):
    """Results with full activation recompute must match results without it."""
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    config = model_configs[model]

    fusion_env_var = "NVTE_BIAS_GELU_NVFUSION"
    # Remember any value set by the caller so it can be restored afterwards.
    previous_fusion_setting = os.environ.get(fusion_env_var)
    if not use_reentrant:
        # Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion
        os.environ[fusion_env_var] = "0"

    try:
        outputs, _ = _test_e2e_full_recompute(
            bs,
            dtype,
            config,
            fp8,
            recipe,
            fp8_model_params,
            recompute=False,
            use_reentrant=use_reentrant,
        )
        outputs_recompute, _ = _test_e2e_full_recompute(
            bs,
            dtype,
            config,
            fp8,
            recipe,
            fp8_model_params,
            recompute=True,
            use_reentrant=use_reentrant,
        )
    finally:
        if not use_reentrant:
            # Reset bias+GELU fusion flag to avoid contaminating other tests,
            # even when one of the runs above raised.
            if previous_fusion_setting is None:
                os.environ.pop(fusion_env_var, None)
            else:
                os.environ[fusion_env_var] = previous_fusion_setting

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols["atol"] = 1e-3
    if fp8 or fp8_model_params:
        tols.update(dict(rtol=0.125, atol=0.0675))
    for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )
727
728
729
730
731
732


def _test_e2e_checkpointing_get_model(config, dtype):
    """Build the CUDA TransformerLayer used by the checkpointing test."""
    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    return TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
        device="cuda",
    )


def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path="checkpoint.pt"):
    """Run ``steps`` forward/backward passes, optionally reloading the model halfway.

    With ``checkpoint`` set, the state dict is saved to ``path`` after the
    first half of the steps, a fresh model is built and loaded from it, and the
    accumulated gradients and RNG state are restored before continuing. Any
    file at ``path`` is removed before returning, also when a step raises.

    Returns:
        A list holding the last output, the input gradient, and the gradient
        of every parameter that requires grad.
    """
    # Declared up front: the checkpoint branch rebinds the module-level RNG snapshots.
    global _cpu_rng_state, _cuda_rng_state

    reset_rng_states()

    te_inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()

    block = _test_e2e_checkpointing_get_model(config, dtype)

    try:
        for _ in range(steps // 2):
            te_out = block(
                te_inp_hidden_states,
                None,
            )
            loss = te_out.sum()
            loss.backward()

        if checkpoint:
            # This process is necessary so that we can start afresh with
            # a new model while erasing all internal state to ensure that
            # loading from a checkpoint gives bitwise identical results.
            # Since gradients are being accumulated, it is important to
            # restore them post loading the checkpoint.
            torch.save(block.state_dict(), path)

            param_grads = []
            for p in block.parameters():
                if p.requires_grad:
                    param_grads.append(p.grad.clone())

            # Snapshot the RNG state now so reset_rng_states() below resumes
            # from this point rather than from the start of the script.
            _cpu_rng_state = torch.get_rng_state()
            _cuda_rng_state = torch.cuda.get_rng_state()

            del block
            block = _test_e2e_checkpointing_get_model(config, dtype)
            block.load_state_dict(torch.load(path, weights_only=False))
            reset_rng_states()

            for p in block.parameters():
                if p.requires_grad:
                    p.grad = param_grads.pop(0)

            assert not param_grads, "Oops!"

        for _ in range(steps // 2):
            te_out = block(
                te_inp_hidden_states,
                None,
            )
            loss = te_out.sum()
            loss.backward()

        torch.cuda.synchronize()
    finally:
        # Do not leave the checkpoint file behind, even on failure.
        if os.path.exists(path):
            os.remove(path)

    outputs = [te_out, te_inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
    """Results with a mid-run save/reload must match an uninterrupted run."""
    config = model_configs[model]
    outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
    outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(dict(rtol=2e-2, atol=2e-3))
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )
839
840
841
842
843
844


def _test_e2e_gpt_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of ``block`` on a seeded random input.

    The RNG state is reset first. ``block`` is called with the input and a
    causal mask passed as ``attention_mask``.

    Returns:
        A list holding the output, the input gradient, and the gradient of
        every parameter that requires grad.
    """
    reset_rng_states()

    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
    inp_attn_mask = get_causal_attn_mask(config.seq_len)

    out = block(inp_hidden_states, attention_mask=inp_attn_mask)
    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
    """Compare a `TransformerLayer` against the `TorchGPT` module.

    Both modules are put in eval mode and given copies of the same
    parameters, then run through `_test_e2e_gpt_accuracy`. The forward output
    and all gradients must agree within dtype-dependent tolerances.
    """
    config = model_configs[model]

    te_gpt = TransformerLayer(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        num_attention_heads=config.num_attention_heads,
        layernorm_epsilon=config.eps,
        attention_dropout=0.1,
        hidden_dropout=0.1,
        params_dtype=dtype,
        fuse_qkv_params=True,
        qkv_weight_interleaved=False,
        parallel_attention_mlp=parallel_attention_mlp,
        device="cuda",
    ).eval()

    torch_gpt = (
        TorchGPT(
            config.hidden_size,
            config.eps,
            config.num_attention_heads,
            parallel_attention_mlp=parallel_attention_mlp,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        # Attention-side layer norm.
        torch_gpt.ln.weight = Parameter(
            te_gpt.self_attention.layernorm_qkv.layer_norm_weight.clone()
        )
        torch_gpt.ln.bias = Parameter(te_gpt.self_attention.layernorm_qkv.layer_norm_bias.clone())
        # Fused QKV projection and output projection.
        torch_gpt.causal_attn.mhsa.in_proj_weight = Parameter(
            te_gpt.self_attention.layernorm_qkv.weight.clone()
        )
        torch_gpt.causal_attn.mhsa.in_proj_bias = Parameter(
            te_gpt.self_attention.layernorm_qkv.bias.clone()
        )
        torch_gpt.causal_attn.mhsa.out_proj.weight = Parameter(
            te_gpt.self_attention.proj.weight.clone()
        )
        torch_gpt.causal_attn.mhsa.out_proj.bias = Parameter(
            te_gpt.self_attention.proj.bias.clone()
        )
        # LayerNorm + MLP.
        torch_gpt.ln_mlp.ln.weight = Parameter(te_gpt.layernorm_mlp.layer_norm_weight.clone())
        torch_gpt.ln_mlp.ln.bias = Parameter(te_gpt.layernorm_mlp.layer_norm_bias.clone())
        torch_gpt.ln_mlp.fc1.weight = Parameter(te_gpt.layernorm_mlp.fc1_weight.clone())
        torch_gpt.ln_mlp.fc1.bias = Parameter(te_gpt.layernorm_mlp.fc1_bias.clone())
        torch_gpt.ln_mlp.fc2.weight = Parameter(te_gpt.layernorm_mlp.fc2_weight.clone())
        torch_gpt.ln_mlp.fc2.bias = Parameter(te_gpt.layernorm_mlp.fc2_bias.clone())

    te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config)
    torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config)

    atol = {
        torch.float32: 5e-3,
        torch.half: 5e-2,
        torch.bfloat16: 1e-1,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

    # Check gradients, only for small model
    if model == "small":
        # The FP32 absolute tolerance is loosened for the gradient comparison.
        atol[torch.float32] = 5e-2
        rtol = {
            torch.float32: 1e-2,
            torch.half: 1e-2,
            torch.bfloat16: 1e-2,
        }
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
    """Run one forward/backward pass of an attention ``block`` on random input.

    A causal mask is built when ``mask_type == "causal"``; otherwise the mask
    is ``None``. The ``attn_mask_type`` keyword is only forwarded when ``te``
    is true. Returns the forward output, the input gradient, and the gradient
    of every parameter of ``block`` that requires grad.
    """
    reset_rng_states()

    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
    inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None

    forward_kwargs = {}
    if te:
        forward_kwargs["attn_mask_type"] = mask_type
    forward_kwargs["attention_mask"] = inp_attn_mask

    out = block(inp_hidden_states, **forward_kwargs)
    loss = out.sum()
    loss.backward()

    # Make sure all queued GPU work has finished before reading gradients.
    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
    """Compare `MultiheadAttention` against the `TorchMHA` module.

    Both modules receive copies of the same QKV and projection parameters and
    are run through `_test_mha_accuracy`; the forward output and all gradients
    must agree within dtype-dependent tolerances.
    """
    config = model_configs[model]

    te_mha = MultiheadAttention(
        config.hidden_size,
        config.num_attention_heads,
        fuse_qkv_params=True,
        params_dtype=dtype,
        qkv_weight_interleaved=False,
        input_layernorm=False,
        device="cuda",
    ).eval()

    torch_mha = (
        TorchMHA(
            config.hidden_size,
            config.num_attention_heads,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_mha.mhsa.in_proj_weight = Parameter(te_mha.qkv.weight.clone())
        torch_mha.mhsa.in_proj_bias = Parameter(te_mha.qkv.bias.clone())
        torch_mha.mhsa.out_proj.weight = Parameter(te_mha.proj.weight.clone())
        torch_mha.mhsa.out_proj.bias = Parameter(te_mha.proj.bias.clone())

    te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type, te=True)
    torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type, te=False)

    # Check output.
    if dtype == torch.float32:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
    else:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

    # Check gradients, only for small model
    if model == "small":
        atol = {
            torch.float32: 5e-2,
            torch.half: 5e-2,
            torch.bfloat16: 5e-2,
        }
        rtol = {
            torch.float32: 1e-2,
            torch.half: 1e-2,
            torch.bfloat16: 1e-2,
        }
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])

1034

def _test_granular_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of a single-input ``block`` on random input.

    If ``block`` returns a list or tuple, only its first element is used as
    the output. Returns the forward output, the input gradient, and the
    gradient of every parameter of ``block`` that requires grad.
    """
    reset_rng_states()

    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    out = block(inp_hidden_states)
    # Builtin types are the idiomatic isinstance targets (same result as the
    # unsubscripted typing.List / typing.Tuple aliases).
    if isinstance(out, (list, tuple)):
        out = out[0]
    loss = out.sum()
    loss.backward()

    # Make sure all queued GPU work has finished before reading gradients.
    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


def _test_dpa_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of an attention core ``block``.

    Random query, key and value tensors are generated and passed together
    with an upper-triangular boolean mask (diagonal excluded). Returns the
    forward output followed by the gradients of query, key and value.
    """
    reset_rng_states()

    # True strictly above the diagonal.
    mask = torch.triu(
        torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1
    )
    query, key, value = [
        torch.randn(
            (config.seq_len, bs, config.num_attention_heads, config.embed),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        for _ in range(3)
    ]

    query.retain_grad()
    key.retain_grad()
    value.retain_grad()

    out = block(query, key, value, attention_mask=mask)
    loss = out.sum()
    loss.backward()

    # Make sure all queued GPU work has finished before reading gradients.
    torch.cuda.synchronize()

    return [out, query.grad, key.grad, value.grad]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_dpa_accuracy(dtype, bs, model):
    """Compare `DotProductAttention` against `TorchDotProductAttention`.

    Both modules are built with dropout disabled and run through
    `_test_dpa_accuracy`; the forward output and the query/key/value
    gradients must agree within tolerance.
    """
    config = model_configs[model]

    te_dpa = (
        DotProductAttention(
            config.num_attention_heads,
            config.embed,
            attention_dropout=0.0,  # disable dropout, FU uses rng differently
        )
        .to(dtype=dtype)
        .cuda()
    )

    torch_dpa = (
        TorchDotProductAttention(
            config.embed,
            0.0,  # dropout
        )
        .to(dtype=dtype)
        .cuda()
    )

    te_outputs = _test_dpa_accuracy(te_dpa, bs, dtype, config)
    torch_outputs = _test_dpa_accuracy(torch_dpa, bs, dtype, config)

    # Check output.
    if dtype == torch.float32:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
    else:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

    # Check gradients of query, key and value.
    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)

1126

class TestReturnBiasModule(nn.Module):
    """Wrap a module so that it always returns a single output tensor.

    ``mod`` is instantiated with ``**kwargs``, which must contain the keys
    ``"return_bias"`` and ``"bias"``. When ``return_bias`` is true the wrapped
    module is expected to return an ``(out, bias)`` pair; the bias is added to
    the output here if ``bias`` is also true. Otherwise the wrapped module's
    result is returned unchanged.
    """

    # The class name matches pytest's ``Test*`` collection pattern but this is
    # a helper, not a test case; opt out of collection explicitly.
    __test__ = False

    def __init__(self, mod, **kwargs):
        super().__init__()
        self.te_module = mod(**kwargs)
        self.return_bias = kwargs["return_bias"]
        self.bias = kwargs["bias"]

    def forward(self, x):
        """Apply the wrapped module to ``x``, folding in a returned bias."""
        if self.return_bias:
            out, bias = self.te_module(x)
            if self.bias:
                out = out + bias
            return out
        return self.te_module(x)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_linear_accuracy(dtype, bs, model, return_bias, bias):
    """Compare `Linear` (wrapped in `TestReturnBiasModule`) with `torch.nn.Linear`.

    Both modules receive copies of the same weight (and bias, when enabled)
    and are run through `_test_granular_accuracy`; the output and all
    gradients must agree within dtype-dependent tolerances.
    """
    config = model_configs[model]

    te_linear = TestReturnBiasModule(
        Linear,
        in_features=config.hidden_size,
        out_features=4 * config.hidden_size,
        params_dtype=dtype,
        return_bias=return_bias,
        bias=bias,
        device="cuda",
    )

    torch_linear = torch.nn.Linear(
        config.hidden_size,
        4 * config.hidden_size,
        bias=bias,
        device="cuda",
        dtype=dtype,
    )

    # Share params
    with torch.no_grad():
        torch_linear.weight = Parameter(te_linear.te_module.weight.clone())
        if bias:
            torch_linear.bias = Parameter(te_linear.te_module.bias.clone())

    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)

    # Check output.
    if model == "small":
        tolerance = 5e-3 if dtype == torch.float32 else 5e-2
        rtol = {
            torch.float32: 1.3e-6,
            torch.half: 1e-2,
            torch.bfloat16: 2e-2,
        }
        # Covers the forward output as well as every gradient.
        for te_output, torch_output in zip(te_outputs, torch_outputs):
            assert_allclose(te_output, torch_output, tolerance, rtol[dtype])

1189

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
    """Compare `RMSNorm` against the `TorchRMSNorm` module.

    Both modules receive a copy of the same weight and are run through
    `_test_granular_accuracy`; the output and gradients must agree within
    dtype-dependent tolerances.
    """
    config = model_configs[model]

    te_rmsnorm = RMSNorm(
        config.hidden_size,
        eps=eps,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_rmsnorm = (
        TorchRMSNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_rmsnorm.weight = Parameter(te_rmsnorm.weight.clone())

    te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)

    atol = {
        torch.float32: 1e-7,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

    # The FP32 absolute tolerance is loosened for the gradient comparison.
    atol[torch.float32] = 2e-3
    rtol = {
        torch.float32: 1.3e-6,
        torch.half: 1e-3,
        torch.bfloat16: 1.6e-2,
    }
    # Check gradients
    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])

1239

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
    """Compare `LayerNorm` against the `TorchLayerNorm` module.

    Both modules receive copies of the same weight and bias and are run
    through `_test_granular_accuracy`; the output and gradients must agree
    within dtype-dependent tolerances.
    """
    config = model_configs[model]

    te_layernorm = LayerNorm(
        config.hidden_size,
        eps=eps,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_layernorm = (
        TorchLayerNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_layernorm.weight = Parameter(te_layernorm.weight.clone())
        torch_layernorm.bias = Parameter(te_layernorm.bias.clone())

    te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)

    atol = {
        torch.float32: 1e-7,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

    rtol = {
        torch.float32: 1.3e-6,
        torch.half: 1e-3,
        torch.bfloat16: 1.6e-2,
    }
    # The FP32 absolute tolerance is loosened for the gradient comparison.
    atol[torch.float32] = 1e-4
    # Check gradients
    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])

1290

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_linear_accuracy(
    dtype, bs, model, normalization, zero_centered_gamma, return_bias, bias
):
    """Compare `LayerNormLinear` against the `TorchLayerNormLinear` module.

    The TE module is wrapped in `TestReturnBiasModule`. Both modules receive
    copies of the same parameters and are run through
    `_test_granular_accuracy`; the output and gradients must agree within
    dtype-dependent tolerances.
    """
    config = model_configs[model]

    te_ln_linear = TestReturnBiasModule(
        LayerNormLinear,
        in_features=config.hidden_size,
        out_features=4 * config.hidden_size,
        eps=config.eps,
        normalization=normalization,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        return_bias=return_bias,
        bias=bias,
        device="cuda",
    )

    torch_ln_linear = (
        TorchLayerNormLinear(
            config.hidden_size,
            4 * config.hidden_size,
            config.eps,
            normalization=normalization,
            zero_centered_gamma=zero_centered_gamma,
            bias=bias,
        )
        .to(dtype=dtype)
        .cuda()
    )

    # Share params
    with torch.no_grad():
        torch_ln_linear.layernorm.weight = Parameter(
            te_ln_linear.te_module.layer_norm_weight.clone()
        )
        # The norm bias is only copied for non-RMSNorm normalizations.
        if normalization != "RMSNorm":
            torch_ln_linear.layernorm.bias = Parameter(
                te_ln_linear.te_module.layer_norm_bias.clone()
            )
        torch_ln_linear.linear.weight = Parameter(te_ln_linear.te_module.weight.clone())
        if bias:
            torch_ln_linear.linear.bias = Parameter(te_ln_linear.te_module.bias.clone())

    te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)

    atol = {
        torch.float32: 2.5e-4,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }
    rtol = {
        torch.float32: 1e-3,
        torch.half: 4e-2,
        torch.bfloat16: 4e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype])

    if model == "small":
        # Gradients use their own (looser) absolute tolerances.
        atol = {
            torch.float32: 1e-3,
            torch.half: 5e-2,
            torch.bfloat16: 5e-2,
        }
        rtol = {
            torch.float32: 1e-3,
            torch.half: 4e-2,
            torch.bfloat16: 4e-2,
        }
        # Check gradients
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])

1374

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):
    """Compare `LayerNormMLP` against the `TorchLayerNormMLP` module.

    The TE module is wrapped in `TestReturnBiasModule`. Both modules receive
    copies of the same parameters and are run through
    `_test_granular_accuracy`; the output and gradients must agree within
    dtype-dependent tolerances.
    """
    config = model_configs[model]

    te_ln_mlp = TestReturnBiasModule(
        LayerNormMLP,
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        return_bias=return_bias,
        bias=bias,
        device="cuda",
    )

    torch_ln_mlp = (
        TorchLayerNormMLP(
            config.hidden_size,
            4 * config.hidden_size,
            activation=activation,
            normalization=normalization,
            bias=bias,
        )
        .to(dtype=dtype)
        .cuda()
    )

    # Share params
    with torch.no_grad():
        torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.te_module.layer_norm_weight.clone())
        # The norm bias is only copied for non-RMSNorm normalizations.
        if normalization != "RMSNorm":
            torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.te_module.layer_norm_bias.clone())
        torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.te_module.fc1_weight.clone())
        torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.te_module.fc2_weight.clone())
        if bias:
            torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.te_module.fc1_bias.clone())
            torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.te_module.fc2_bias.clone())

    te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)

    atol = {
        torch.float32: 2e-2,
        torch.half: 5e-2,
        torch.bfloat16: 5e-2,
    }

    rtol = {
        torch.float32: 1e-3,
        torch.half: 4e-2,
        torch.bfloat16: 4e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype])

    # Check gradients, only for small model
    rtol = {
        torch.float32: 1e-3,
        torch.half: 1e-2,
        torch.bfloat16: 4e-2,
    }
    # Reduced-precision absolute tolerances are loosened for gradients.
    atol[torch.half] = 2e-1
    atol[torch.bfloat16] = 2e-1
    if model == "small":
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


def _test_grouped_linear_accuracy(
    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
):
    """Run one forward/backward pass of a grouped or sequential linear ``block``.

    The first input dimension is divided into ``num_gemms`` random chunks, one
    of which is deliberately empty when ``num_gemms > 1``. A `GroupedLinear`
    block receives the whole input plus the split sizes; any other block is
    indexed per chunk and the results are concatenated.

    Returns the forward output, the input gradient, and one gradient per
    parameter that requires grad (``main_grad`` when that attribute is set,
    ``grad`` otherwise). ``fuse_wgrad_accumulation`` is not read here.
    """
    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    if num_gemms > 1:
        # Split sizes are kept at multiples of ``split_size``, which depends
        # on the FP8 recipe in use.
        split_size = 1
        if fp8:
            if recipe.delayed():
                split_size = 16
            if recipe.mxfp8():
                split_size = 128
        m = config.seq_len // split_size
        dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
        # Manually add a zero: duplicating a boundary yields an empty split.
        # With num_gemms == 2 there is no boundary to duplicate, so use 0.
        dist.append(dist[-1] if dist else 0)
        m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
        m_splits = m_splits * split_size
        assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
    else:
        m_splits = torch.tensor([config.seq_len])

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, GroupedLinear):
            # Split sizes are scaled by the batch size for GroupedLinear.
            m_splits = m_splits * bs
            out = block(inp_hidden_states, m_splits.tolist())
        else:
            out = torch.cat(
                [
                    block[i](inp)
                    for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                ]
            )
    loss = out.sum()
    loss.backward()

    # Make sure all queued GPU work has finished before reading gradients.
    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            if getattr(p, "main_grad", None) is not None:
                outputs.append(p.main_grad)
                assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
            else:
                outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_grouped_linear_accuracy(
    dtype,
    num_gemms,
    bs,
    model,
    fp8,
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
    parallel_mode=None,
):
    """Check that `GroupedLinear` matches a list of `Linear` modules exactly.

    Both variants receive copies of the same weights and biases (and the same
    initial ``main_grad`` buffers when ``fuse_wgrad_accumulation`` is set) and
    are run through `_test_grouped_linear_accuracy`. All outputs and gradients
    are compared with zero tolerance.

    ``parallel_mode`` is not parametrized; it is supplied by the wrapper tests
    that call this function directly.
    """
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
        pytest.skip("MXFP8 unsupported for grouped linear.")
    if fp8 and recipe.float8_current_scaling():
        pytest.skip("Float8 Current Scaling unsupported for grouped linear.")

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=True,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
        ).eval()
        sequential_linear = torch.nn.ModuleList(
            [
                Linear(
                    config.hidden_size,
                    4 * config.hidden_size,
                    bias=True,
                    params_dtype=dtype,
                    parallel_mode=parallel_mode,
                    device="cuda",
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                ).eval()
                for _ in range(num_gemms)
            ]
        )

    # Share params
    with torch.no_grad():
        for i in range(num_gemms):
            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
            sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
            if fuse_wgrad_accumulation:
                # Start both variants from the same random FP32 main_grad.
                weight_i = getattr(grouped_linear, f"weight{i}")
                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
                sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()

    # Force the sequential_linear and grouped_linear to use hipblaslt rather than hipblas
    # NOTE(review): the variable stays set after this test returns -- confirm
    # whether later tests are expected to inherit it.
    if IS_HIP_EXTENSION:
        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"

    outputs_ref = _test_grouped_linear_accuracy(
        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
    )
    outputs = _test_grouped_linear_accuracy(
        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
    )

    # Should be a bit-wise match
    for o, o_ref in zip(outputs, outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("parallel_mode", ["column", "row"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe):
    """Split the tests to save CI time

    Runs `test_grouped_linear_accuracy` once per parallel mode and recipe
    with a single fixed configuration instead of the full parameter grid.
    """
    test_grouped_linear_accuracy(
        dtype=torch.float32,
        num_gemms=6,
        bs=2,
        model="126m",
        fp8=True,
        recipe=recipe,
        fp8_model_params=True,
        parallel_mode=parallel_mode,
        fuse_wgrad_accumulation=True,
    )


@pytest.mark.parametrize("recipe", fp8_recipes)
def test_grouped_linear_accuracy_single_gemm(recipe):
    """Split the tests to save CI time

    Runs `test_grouped_linear_accuracy` with ``num_gemms=1`` once per recipe
    with a single fixed configuration instead of the full parameter grid.
    """
    test_grouped_linear_accuracy(
        dtype=torch.float32,
        num_gemms=1,
        bs=2,
        model="126m",
        fp8=True,
        recipe=recipe,
        fp8_model_params=True,
        fuse_wgrad_accumulation=True,
    )


def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
    """Run one forward/backward pass of a grouped linear ``block`` on uneven splits.

    The flattened input is divided into ``num_gemms`` random positive-sized
    chunks. A `TorchGroupedLinearWithPadding` block is called directly with
    those splits. Any other block is, when ``fp8`` is true, fed an input that
    is zero-padded per chunk to a multiple of 16 rows and has the padding
    stripped from its output again.

    Returns the forward output, the input gradient, and the gradient of every
    parameter of ``block`` that requires grad.
    """

    def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
        """Padding tensor shapes to multiples of 16."""
        padded_tokens_per_expert = [
            (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert
        ]
        hidden_states = torch.split(hidden_states, tokens_per_expert)
        padded_hidden_states = []
        for hidden_state, actual_num_tokens, padded_num_tokens in zip(
            hidden_states, tokens_per_expert, padded_tokens_per_expert
        ):
            padded_hidden_states.append(hidden_state)
            if padded_num_tokens > actual_num_tokens:
                # Append zero rows after the chunk to reach the padded size.
                pad_tensor = torch.zeros(
                    padded_num_tokens - actual_num_tokens,
                    hidden_state.shape[1],
                    dtype=hidden_state.dtype,
                    device=hidden_state.device,
                )
                padded_hidden_states.append(pad_tensor)
        padded_hidden_states = torch.cat(padded_hidden_states, dim=0)
        return padded_hidden_states, padded_tokens_per_expert

    def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert):
        """Drop the padding rows, keeping the leading real rows of each chunk."""
        padded_chunks = torch.split(
            padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert
        )
        hidden_states = torch.cat(
            [chunk[: actual_tokens_per_expert[i]] for i, chunk in enumerate(padded_chunks)],
            dim=0,
        )

        return hidden_states

    def _generate_random_numbers(n, total_sum):
        """Return ``n`` positive integers that add up to ``total_sum``."""
        if n <= 0:
            return []

        # reset seed
        random.seed(seed)

        # n - 1 distinct cut points strictly inside (0, total_sum); the gaps
        # between consecutive boundaries are the split sizes. Building the
        # boundary list explicitly also handles n == 1 (no cut points).
        breaks = sorted(random.sample(range(1, total_sum), n - 1))
        bounds = [0] + breaks + [total_sum]
        return [hi - lo for lo, hi in zip(bounds, bounds[1:])]

    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.seq_len * bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, TorchGroupedLinearWithPadding):
            out = block(inp_hidden_states, m_splits)
        else:
            if fp8:
                padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
                    inp_hidden_states, m_splits
                )
                padded_out = block(padded_inp_hidden_states, padding_m_splits)
                out = _unpad_tensor_for_fp8(padded_out, m_splits, padding_m_splits)
            else:
                out = block(inp_hidden_states, m_splits)

    loss = out.sum()
    loss.backward()

    # Make sure all queued GPU work has finished before reading gradients.
    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy(
    dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None
):
    """Check that the padded grouped linear matches `GroupedLinear` exactly.

    A `TorchGroupedLinearWithPadding` and a `GroupedLinear` are built with the
    same sizes, the reference module receives clones of the weights held by
    the padded module's inner `linear_fn`, and both are run through
    `_test_padding_grouped_linear_accuracy`. Every tensor returned by the two
    runs is compared with zero tolerance.
    """
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
        pytest.skip("MXFP8 unsupported for grouped linear.")
    if fp8 and recipe.float8_current_scaling():
        pytest.skip("Float8 Current Scaling unsupported for grouped linear.")

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = TorchGroupedLinearWithPadding(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            fp8=fp8,
        ).eval()

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        ref_grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
        ).eval()

    # Share params: the reference module gets its own copy of each weight so
    # both modules start from identical values.
    with torch.no_grad():
        inner_grouped_linear = grouped_linear.linear_fn
        for i in range(num_gemms):
            setattr(
                ref_grouped_linear,
                f"weight{i}",
                Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
            )

    outputs = _test_padding_grouped_linear_accuracy(
        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
    )
    outputs_ref = _test_padding_grouped_linear_accuracy(
        ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
    )

    # Should be bit-wise match
    for o, o_ref in zip(outputs, outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


1783
1784
1785
1786
1787
1788
1789
def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
    """Run a short SGD training loop on `block`, optionally via a CUDA graph.

    Three warm-up steps are executed on a side stream. When `graph` is true a
    single training step is captured into a `torch.cuda.CUDAGraph` and then
    replayed once after new data has been copied into the static input and
    target buffers; otherwise that step is executed eagerly.

    Returns:
        A tuple `(output, grads)`: a clone of the final step's output, and a
        list holding the static input's gradient followed by the gradient of
        every parameter of `block` that requires grad.
    """
    reset_rng_states()

    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for graph capture.
    static_input = torch.randn(
        config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
    )
    static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype)

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    # Basic training loop.
    def train_step():
        # Gradients are zeroed in place (not set to None) so the same grad
        # tensors are reused on every step.
        optimizer.zero_grad(set_to_none=False)
        out = block(static_input)
        loss = loss_fn(out, static_target)
        loss.backward()
        optimizer.step()
        return out

    # Warmup steps in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            train_step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture graph.
    g = None
    static_output = None
    if graph:
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_output = train_step()

    # Run with new data.
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    if graph:
        g.replay()
    else:
        static_output = train_step()

    grads = [static_input.grad]
    grads.extend(p.grad for p in block.parameters() if p.requires_grad)

    with torch.no_grad():
        output = static_output.clone()
    return output, grads


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_cuda_graph(dtype, bs, model):
    """Compare an eager training step with a CUDA-graph-replayed one.

    Two `TransformerLayer` instances are built from the same arguments, the
    second receives a copy of the first one's parameters, and both go through
    `_test_gpt_e2e_cuda_graph` (eager vs. graphed). Outputs, parameters and
    gradients must agree within a tolerance of 1e-3.
    """
    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block_args = (
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
    )
    block_kwargs = dict(
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        device="cuda",
    )
    block = TransformerLayer(*block_args, **block_kwargs)
    graphed_block = TransformerLayer(*block_args, **block_kwargs)
    # Make both layers start from identical parameter values.
    with torch.no_grad():
        for param1, param2 in zip(block.parameters(), graphed_block.parameters()):
            param2.copy_(param1)

    out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
    graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
    params = list(block.parameters())
    graphed_params = list(graphed_block.parameters())

    # Check that results match
    assert_allclose(out, graphed_out, 1e-3)
    assert_allclose(params, graphed_params, 1e-3)
    assert_allclose(grads, graphed_grads, 1e-3)
1885
1886


1887
def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
    """Run one FP8 forward/backward pass through a freshly built `TransformerLayer`.

    The layer is constructed inside `fp8_model_init(enabled=fp8_model_params)`
    and always executed under `fp8_autocast(enabled=True)` with a causal
    attention mask.

    Returns:
        A list with the layer output, the gradient of the input hidden
        states, and the gradient of every parameter that requires grad.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    with fp8_model_init(enabled=fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    te_inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)

    with fp8_autocast(enabled=True, fp8_recipe=recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    # Wait for all queued GPU work before the results are handed back.
    torch.cuda.synchronize()

    outputs = [te_out, te_inp_hidden_states.grad]
    outputs.extend(p.grad for p in block.parameters() if p.requires_grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
    """Compare FP8 execution with and without FP8 model parameters.

    `_test_gpt_fp8_parameters` is run twice, with `fp8_model_params` set to
    False and then True, and the two result lists are compared element by
    element within loose FP8 tolerances.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    config = model_configs[model]

    outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe)
    outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True, recipe)

    # Check that results match
    tols = dict(rtol=0.125, atol=0.0675)
    for i, (ref, test) in enumerate(zip(outputs, outputs_fp8_params)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )

1961
1962
1963

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_transformer_layer_hidden_states_format(dtype, bs, model):
    """Check that `TransformerLayer` gives the same output for sbhd, bshd and thd inputs.

    Three layers that differ only in `attn_input_format` (the thd layer also
    uses the "padding_causal" self-attention mask type) are built from the
    same seed, verified to hold identical parameters, and fed the same data
    laid out in each format. The thd comparison only runs for non-float32
    dtypes on devices where `sm_80plus` is true.
    """
    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    def _build_block(attn_input_format, **extra_kwargs):
        # Set `torch.manual_seed` to make sure the weights are identical to the
        # other layers. Set `*dropout` values to 0 to make sure the forward pass
        # is identical to the other layers.
        torch.manual_seed(0)
        return TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0,
            attention_dropout=0,
            kv_channels=config.embed,
            params_dtype=dtype,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            device="cuda",
            attn_input_format=attn_input_format,
            **extra_kwargs,
        )

    block_sbhd = _build_block("sbhd")
    block_bshd = _build_block("bshd")
    block_thd = _build_block("thd", self_attn_mask_type="padding_causal")

    for (n1, p1), (n2, p2), (n3, p3) in zip(
        block_bshd.named_parameters(), block_sbhd.named_parameters(), block_thd.named_parameters()
    ):
        assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical"

    x_sbhd = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    # Same data, re-laid-out for the other two formats.
    x_bshd = x_sbhd.transpose(0, 1).contiguous()
    x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous()
    # Cumulative sequence lengths: every sequence has the full `seq_len` tokens.
    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len

    # To make sure forward is also identical (just in case some module decides
    # to act fancy)
    torch.manual_seed(0)
    y_sbhd = block_sbhd(x_sbhd)

    # To make sure forward is also identical (just in case some module decides
    # to act fancy)
    torch.manual_seed(0)
    y_bshd = block_bshd(x_bshd)

    # Check that results match
    torch.testing.assert_close(
        y_bshd,
        y_sbhd.transpose(0, 1).contiguous(),
    )

    # THD is not supported in float32 and on GPUs older than Ampere, skip the test here
    if dtype != torch.float32 and sm_80plus:
        # To make sure forward is also identical (just in case some module decides
        # to act fancy)
        torch.manual_seed(0)
        y_thd = block_thd(
            x_thd,
            cu_seqlens_q=x_thd_cumsum,
            cu_seqlens_kv=x_thd_cumsum,
            max_seqlen_q=config.seq_len,
            max_seqlen_kv=config.seq_len,
        )

        torch.testing.assert_close(
            y_bshd,
            y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
        )

2083
2084
2085
2086
2087
2088
2089
2090

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model_key", model_configs_inference.keys())
@pytest.mark.parametrize("use_RoPE", all_boolean)
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
@pytest.mark.parametrize("is_paged", [False, True])
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
    """Compare a full-sequence forward pass with token-by-token KV-cache generation.

    The model (`TransformerLayer` or `MultiheadAttention`, chosen by `module`)
    is run once on the whole sequence and then once per position with
    `InferenceParams`; the per-position outputs are assembled and must match
    the full output within a dtype-dependent tolerance.
    """
    reset_rng_states()

    # NOTE(review): every FusedAttention case is skipped here, so the
    # FusedAttention handling further down is currently unreachable.
    if backend in ["FusedAttention"]:
        pytest.skip("Not support FusedAttention")
    if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
        pytest.skip("FusedAttention and FlashAttention do not support FP32")
    if use_RoPE:
        pytest.skip("KV cache does not support starting positions for RoPE")

    # Select the attention backend through environment variables.
    # NOTE(review): these values are not restored when the test finishes.
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"

    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    elif backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    elif backend == "UnfusedAttention":
        os.environ["NVTE_UNFUSED_ATTN"] = "1"

    config = model_configs_inference[model_key]

    S = config.seq_len
    B = bs
    H = config.num_attention_heads
    D = config.hidden_size
    head_size = config.embed
    layer_number = 1

    # Limits the max size of KV-cache
    B_max = B
    S_max = S

    if module == "TransformerLayer":
        model = TransformerLayer(
            hidden_size=D,
            ffn_hidden_size=4 * D,
            num_attention_heads=H,
            attn_input_format=input_format,
            self_attn_mask_type="causal",
            enc_dec_attn_mask_type="causal",
            layer_number=layer_number,
            attention_dropout=0.0,
            params_dtype=dtype,
            device="cuda",
        ).eval()
    else:
        model = (
            MultiheadAttention(
                hidden_size=D,
                num_attention_heads=H,
                qkv_format=input_format,
                layer_number=layer_number,
                attention_dropout=0.0,
                attn_mask_type="causal",
                params_dtype=dtype,
            )
            .cuda()
            .eval()
        )

    inference_params = InferenceParams(
        max_batch_size=B_max,
        max_sequence_length=S_max,
        num_heads_kv=H,
        head_dim_k=head_size,
        dtype=dtype,
        is_paged=is_paged,
        total_num_pages=int(B_max * S_max / 256),
        page_size=256,
    )

    rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")

    # Named `inp` rather than `input` so the builtin is not shadowed.
    inp = torch.randn((S, B, D), dtype=dtype, device="cuda")
    if input_format == "bshd":
        inp = inp.transpose(0, 1).contiguous()

    incremental_output = torch.zeros_like(inp)

    # Generate output for the entire sequence
    full_output = model(hidden_states=inp, rotary_pos_emb=rotary_freqs if use_RoPE else None)

    # The mask-type keyword differs between the two module kinds; it is the
    # same for every step, so build it once.
    mask_type = "padding"
    if module == "TransformerLayer":
        mask_kwargs = {"self_attn_mask_type": mask_type}
    else:
        mask_kwargs = {"attn_mask_type": mask_type}

    # Incrementally generate outputs using KV-cache
    # Each of the B sequences advances by one token per step.
    step_dict = OrderedDict((seq_idx, 1) for seq_idx in range(B))
    for i in range(S):
        inference_params.pre_step(step_dict)

        if input_format == "sbhd":
            incremental_input = inp[i].view(1, B, D)
        else:
            incremental_input = inp[:, i, :].view(B, 1, D)

        seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda")
        cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
        cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
        cu_seqlens_kv = cu_seqlens_q.clone()

        line_output = model(
            hidden_states=incremental_input,
            inference_params=inference_params,
            rotary_pos_emb=rotary_freqs if use_RoPE else None,
            **mask_kwargs,
            max_seqlen_q=1,
            max_seqlen_kv=S,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )

        if input_format == "sbhd":
            incremental_output[i, :, :] = line_output.view(B, D)
        else:
            incremental_output[:, i, :] = line_output.view(B, D)

    if module == "TransformerLayer":
        atol = {
            torch.float32: 5e-3,
            torch.half: 5e-3,
            torch.bfloat16: 5e-2,
        }
    else:
        atol = {
            torch.float32: 1e-3,
            torch.half: 1e-3,
            torch.bfloat16: 1e-2,
        }

    # Check if the fully generated output matches the one generated incrementally
    assert_allclose(full_output, incremental_output, atol[dtype])
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252


@pytest.mark.parametrize(
    "shape",
    [
        (1, 127, 128, 512),
        (8, 15, 128, 512),
        (8, 1027, 128, 512),
        (16, 10027, 128, 512),
    ],
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm(shape, dtype, layout, accumulate):
    """Check `general_grouped_gemm` against one `general_gemm` call per group.

    `shape` is `(z, m, k, n)`: `z` groups whose row counts are a random split
    of `m`, with `k` and `n` as the remaining matrix dimensions. The operands
    are built per `layout` ("TN", "NN" or "NT") and the grouped result must
    match the per-group reference bit-for-bit.
    """
    torch.manual_seed(0)
    z, m, k, n = shape

    # Random split of m rows into z groups (groups may be empty).
    dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist()
    m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
    assert m_splits.sum() == m and len(m_splits) == z
    m_splits = m_splits.tolist()

    if layout == "TN":
        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
        B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
        out = [torch.randn(m, n, dtype=dtype, device="cuda")]  # output
        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
        grad = False
        single_output = True
    elif layout == "NN":
        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
        B = list(
            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
        )  # grad_output
        out = [torch.randn(m, k, dtype=dtype, device="cuda")]  # dgrad
        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
        grad = True
        single_output = True
    else:  # layout == "NT"
        A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
        B = list(
            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
        )  # grad_output
        out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # wgrad
        out_ref = [o.clone() for o in out]
        grad = True
        single_output = False

    # Force the sequential_linear and grouped_linear to use hipblaslt rather than hipblas
    # NOTE(review): the variable stays set for the rest of the process.
    if IS_HIP_EXTENSION:
        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"

    # Reference: one GEMM per group, written into the matching reference slice.
    for a_i, b_i, ref_i in zip(A, B, out_ref):
        general_gemm(
            a_i,
            b_i,
            get_workspace(),
            dtype,
            grad=grad,
            accumulate=accumulate,
            layout=layout,
            out=ref_i,
        )
    if single_output:
        out_ref = [torch.cat(out_ref)]

    general_grouped_gemm(
        A,
        B,
        out,
        dtype,
        get_multi_stream_cublas_workspace(),
        m_splits=m_splits,
        grad=grad,
        accumulate=accumulate,
        layout=layout,
        single_output=single_output,
    )

    # should be bit-wise match
    for o, o_ref in zip(out, out_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize(
    "shape",
    [
        (1, 128, 128, 512),
        (8, 1024, 128, 512),
        (16, 4096, 128, 512),
    ],
)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("accumulate", [False, True])
def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
    """Check `general_grouped_gemm` on FP8-quantized operands against per-group `general_gemm`.

    `shape` is `(z, m, k, n)`; the `m` rows are split evenly into `z` groups.
    Weights and inputs are created in bfloat16, quantized with
    `Float8Quantizer`, and the grouped result must match the per-group
    reference bit-for-bit.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)

    z, m, k, n = shape
    m_splits = [m // z] * z

    dtype = torch.bfloat16
    A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
    B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # input
    out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # output
    out_ref = [o.clone() for o in out]

    # fp8 should be robust enough to this fake scale
    scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze()
    amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda")

    # NOTE(review): `fp8_dtype` is not used below; both quantizer lists are
    # fixed to kFloat8E4M3. Confirm whether that is intentional.
    a_quantizers = [
        Float8Quantizer(
            scale.clone(),
            amax.clone(),
            tex.DType.kFloat8E4M3,
        )
        for _ in range(z)
    ]
    b_quantizers = [
        Float8Quantizer(
            scale.clone(),
            amax.clone(),
            tex.DType.kFloat8E4M3,
        )
        for _ in range(z)
    ]

    # Quantize each group's weight and input with its own quantizer.
    A_fp8 = []
    B_fp8 = []
    for a_quantizer, b_quantizer, a_i, b_i in zip(a_quantizers, b_quantizers, A, B):
        A_fp8.append(a_quantizer(a_i))
        B_fp8.append(b_quantizer(b_i))

    # baseline
    for a_fp8_i, b_fp8_i, ref_i in zip(A_fp8, B_fp8, out_ref):
        general_gemm(
            a_fp8_i,
            b_fp8_i,
            get_workspace(),
            dtype,
            out=ref_i,
            accumulate=accumulate,
        )
    general_grouped_gemm(
        A_fp8,
        B_fp8,
        out,
        dtype,
        get_multi_stream_cublas_workspace(),
        m_splits=m_splits,
        accumulate=accumulate,
    )

    # should be bit-wise match
    for o, o_ref in zip(out, out_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
2433
2434
2435


def test_noncontiguous():
    """Modules must give matching results for a non-contiguous input and its contiguous copy.

    `LayerNorm`, `RMSNorm` and `Linear` are each instantiated twice with
    identical parameter data; one instance receives a transposed
    (non-contiguous) tensor, the other its contiguous copy, and the collected
    outputs and gradients are compared with a tolerance of 1e-7.
    """

    def _create2modules(m, params):
        """Build two instances of `m` and copy the first one's parameter data into the second."""
        first = m(*params)
        second = m(*params)
        for src, dst in zip(first.parameters(), second.parameters()):
            dst.data = src.data.clone()
        return first, second

    def _run_module(m, inp):
        """Run forward and backward; return the output, input grad (if set) and parameter grads."""
        result = m(inp)
        result.sum().backward()
        collected = [result]
        if inp.grad is not None:
            collected.append(inp.grad)
        collected.extend(p.grad for p in m.parameters() if p.requires_grad)
        return collected

    transposed = torch.randn((128, 256), device="cuda", requires_grad=True).T
    assert not transposed.is_contiguous(), "The test is supposed to test noncontiguous input."

    contiguous = transposed.contiguous()

    # LayerNorm, RMSNorm, then GEMM (Linear) -- same order as before.
    module_specs = (
        (LayerNorm, [128]),
        (RMSNorm, [128]),
        (Linear, [128, 128]),
    )
    for module_cls, ctor_args in module_specs:
        mod_for_transposed, mod_for_contiguous = _create2modules(module_cls, ctor_args)
        outT = _run_module(mod_for_transposed, transposed)
        out = _run_module(mod_for_contiguous, contiguous)

        assert_allclose(out, outT, 1e-7)