test_numerics.py 76.7 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
from collections import OrderedDict
6
import math
7
import os
8
from typing import Dict, List, Tuple, Optional
9
import pytest
10
import copy
11
import random
12
13
14
15
16

import torch
import torch.nn as nn
from torch.nn import Parameter

17
18
19
20
21
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    fp8_autocast,
    fp8_model_init,
)
22
23
24
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
25
    attention_mask_func,
26
    is_bf16_compatible,
27
28
)
from transformer_engine.pytorch import (
29
30
31
32
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
33
    GroupedLinear,
34
35
36
37
    MultiheadAttention,
    RMSNorm,
    TransformerLayer,
    LayerNorm,
38
39
    Fp8Padding,
    Fp8Unpadding,
40
)
41
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
42
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
43
44
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
45
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
46
from transformer_engine.pytorch.utils import get_device_compute_capability
47
from transformer_engine.common import recipe
48
import transformer_engine_torch as tex
49

50
# Only run FP8 tests on supported devices.
# Each probe returns an (available, reason_if_not) pair used by pytest.skip.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)

# Whether the current device reports compute capability 8.0 or newer.
sm_80plus = get_device_compute_capability() >= (8, 0)

# Seed the CPU and CUDA generators once, then snapshot the seeded state so
# reset_rng_states() can rewind to this exact point.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

# Raise torch._dynamo's recompile limit (presumably to tolerate the many
# parametrizations in this file -- confirm).
torch._dynamo.config.recompile_limit = 16

68
69
70
71
72
73
74
75
76
77
78
79

class ModelConfig:
    """Plain container for the transformer dimensions used by these tests.

    Arguments are stored verbatim as attributes of the same name; nothing is
    validated or derived.
    """

    def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
        # Tuple assignment binds left to right, so attribute insertion order
        # matches the argument order.
        (
            self.hidden_size,
            self.eps,
            self.num_attention_heads,
            self.embed,
            self.num_layers,
            self.seq_len,
        ) = (hidden_size, eps, num_attention_heads, embed, num_layers, seq_len)


# name -> ModelConfig(hidden_size, eps, num_attention_heads, embed, num_layers, seq_len)
model_configs = {
    "small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}

# Same field layout as ``model_configs``; the 126m entry uses seq_len=256.
model_configs_inference = {
    # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]

# bf16 requires sm_80 or higher, so it is appended only when supported.
param_types = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_compatible() else []
)

batch_sizes = [1, 2]

all_boolean = [True, False]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]

all_normalizations = ["LayerNorm", "RMSNorm"]

mask_types = ["causal", "no_mask"]

# One instance of each FP8 scaling recipe exercised by the tests.
fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),
    recipe.Float8CurrentScaling(),
    recipe.Float8BlockScaling(),
]

113

114
115
116
117
def get_causal_attn_mask(sq: int, device: str = "cuda") -> torch.Tensor:
    """Return a ``[sq, sq]`` boolean mask that is True strictly above the diagonal.

    Row ``i`` is True at columns ``j > i``, i.e. at the "future" positions
    relative to row ``i``.

    Args:
        sq: sequence length (number of rows and columns).
        device: device to allocate the mask on; defaults to ``"cuda"``.
    """
    return torch.triu(torch.ones(sq, sq, device=device), diagonal=1).bool()


118
119
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.

    Returns:
        A new ``{"rtol": ..., "atol": ...}`` dict on every call, so callers
        may mutate it freely.

    Raises:
        ValueError: if ``dtype`` is not float32, float16 or bfloat16.
    """
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    raise ValueError(f"Unsupported dtype ({dtype})")


def assert_allclose(
    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
) -> bool:
    """Ensures two lists of tensors are element-wise close.

    Each pair ``(l1[i], l2[i])`` is compared with ``torch.allclose`` using
    ``atol`` and ``rtol`` (torch's default rtol of 1e-5 when ``rtol`` is None).

    Returns:
        True when every pair is close.

    Raises:
        AssertionError: on a length mismatch, or for the first pair that is
            not close. The message reports the element with the largest
            out-of-tolerance difference when one exists.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    # Use torch.allclose's default when rtol is omitted, so the diagnostic
    # below does not fail with ``None * tensor``.
    effective_rtol = 1e-05 if rtol is None else rtol
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if torch.allclose(t1, t2, atol=atol, rtol=effective_rtol):
            continue
        diff = torch.abs(t1 - t2)
        tol = atol + (effective_rtol * torch.abs(t2))
        exceed_mask = diff > tol
        if exceed_mask.any():
            indices = torch.nonzero(exceed_mask, as_tuple=True)
            exceeding = diff[exceed_mask]
            max_diff = exceeding.max()
            max_idx = (exceeding == max_diff).nonzero(as_tuple=True)[0][0]
            max_location = [idx[max_idx].item() for idx in indices]
            msg = (
                f"Outputs not close enough in tensor at idx={i}. "
                f"Maximum difference at location {max_location} "
                f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
                f"(diff {max_diff.item()})."
            )
        else:
            # allclose failed, but no element satisfies diff > tol. This
            # happens e.g. with NaNs, since NaN comparisons are always False.
            msg = f"Outputs not close enough in tensor at idx={i}."
        raise AssertionError(msg)
    return True
159
160
161


def reset_rng_states() -> None:
    """Rewind the CPU and CUDA generators to the module-level saved states.

    Restores ``_cpu_rng_state`` first, then ``_cuda_rng_state``.
    """
    for restore, saved in (
        (torch.set_rng_state, _cpu_rng_state),
        (torch.cuda.set_rng_state, _cuda_rng_state),
    ):
        restore(saved)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that resets FP8GlobalStateManager after every test."""
    # No setup: the test body runs while suspended at this yield.
    yield None
    FP8GlobalStateManager.reset()
171
172


173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
class TorchScaledMaskedSoftmax(nn.Module):
    """Reference softmax over the last dim with optional scaling and masking.

    The computation runs in float32 and the probabilities are cast back to the
    input dtype.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        orig_dtype = inp.dtype
        logits = inp.float()

        if scale is not None:
            logits = logits * scale
        if mask is not None:
            # Masking is delegated to TE's attention_mask_func.
            logits = attention_mask_func(logits, mask)

        return torch.nn.Softmax(dim=-1)(logits).to(orig_dtype)


class TorchDotProductAttention(torch.nn.Module):
    """Reference (unfused) scaled dot-product attention in plain PyTorch.

    Inputs use the ``[sq, b, np, hn]`` layout named in the shape comments
    below. The output is ``[sq, b, np * hn]``. Scores are scaled by
    ``1 / sqrt(kv_channels)`` via the ``alpha`` of ``baddbmm``, then passed
    through the masked softmax and dropout.
    """

    def __init__(
        self,
        kv_channels: int,
        attention_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.norm_factor = math.sqrt(kv_channels)
        self.scale_mask_softmax = TorchScaledMaskedSoftmax()
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # Preallocate the [b * np, sq, sk] result on the inputs' own device
        # (not torch.cuda.current_device()), so CPU tensors and tensors on a
        # non-current GPU work too. Its contents are ignored because baddbmm
        # is called with beta=0.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
        attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        context_layer = context_layer.view(seqlen, batch_size, -1)

        return context_layer

280

281
class TorchLayerNorm(nn.Module):
    """Reference LayerNorm computed in float32 and cast back to the input dtype.

    With ``zero_centered_gamma`` the learnable ``weight`` holds an offset from
    1, so the effective scale is ``1 + weight``. Otherwise ``weight`` is the
    scale itself.
    """

    def __init__(self, in_features: int, eps: float, zero_centered_gamma: bool):
        super().__init__()
        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        # Initialize so the effective scale starts at 1 in both modes: a zero
        # offset when zero-centered, a unit scale otherwise. (The inverse
        # choice would give scale 2 or scale 0, respectively.)
        init_fn = torch.zeros if zero_centered_gamma else torch.ones
        # Assigning nn.Parameter attributes registers them with the module.
        self.weight = nn.Parameter(init_fn(in_features))
        self.bias = nn.Parameter(torch.zeros(in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = 1 + self.weight if self.zero_centered_gamma else self.weight
        out = torch.nn.functional.layer_norm(
            x.to(torch.float32),
            (self.in_features,),
            weight=w.to(torch.float32),
            bias=self.bias.to(torch.float32),
            eps=self.eps,
        )
        return out.to(x.dtype)

304

305
306
# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
    """Reference RMSNorm: ``x / sqrt(mean(x**2, dim=-1) + eps) * gamma``.

    The sum of squares is computed in float32 and the result is cast back to
    the input dtype. With ``zero_centered_gamma`` the effective scale is
    ``1 + weight``.
    """

    def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
        super().__init__()

        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        # Initialize so the effective scale starts at 1 in both modes: a zero
        # offset when zero-centered, a unit scale otherwise. Assigning the
        # nn.Parameter attribute registers it with the module.
        init_fn = torch.zeros if zero_centered_gamma else torch.ones
        self.weight = nn.Parameter(init_fn(in_features))

    def forward(self, x):
        norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
        d_x = self.in_features

        rms_x2 = norm_x2 / d_x + self.eps
        r_rms_x = rms_x2 ** (-1.0 / 2)
        # r_rms_x is float32, so this product is promoted to float32.
        x_normed = x * r_rms_x

        w = self.weight.float()
        if self.zero_centered_gamma:
            w = 1 + w
        return (w * x_normed).to(x.dtype)
330

331

332
class TorchLayerNormLinear(nn.Module):
    """Reference normalization (LayerNorm or RMSNorm) followed by ``nn.Linear``."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float,
        normalization: str = "LayerNorm",
        zero_centered_gamma: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        norm_classes = {"LayerNorm": TorchLayerNorm, "RMSNorm": TorchRMSNorm}
        if normalization not in norm_classes:
            raise RuntimeError("Unsupported normalization")
        self.layernorm = norm_classes[normalization](
            in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
        )
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.layernorm(x)
        return self.linear(normed)


class TorchMHA(nn.Module):
    """Self-attention reference built on ``nn.MultiheadAttention`` (seq-first inputs)."""

    def __init__(self, hidden_size: int, num_attention_heads: int):
        super().__init__()
        self.mhsa = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            dropout=0.1,
            bias=True,
            batch_first=False,
        )

    def forward(self, x, attention_mask=None):
        """Attend ``x`` to itself and return only the attention output."""
        result = self.mhsa(x, x, x, attn_mask=attention_mask, need_weights=False)
        return result[0] if isinstance(result, tuple) else result

377

378
379
380
class TorchQuickGELU(nn.Module):
    """Quick GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(1.702 * input)
        return input * gate
381

382

383
384
385
386
class TorchSquaredRELU(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return (input > 0) * input * input

387

388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
class TorchGroupedLinearWithPadding(nn.Module):
    """GroupedLinear that, when ``fp8`` is set, runs between Fp8Padding and Fp8Unpadding.

    In FP8 mode, ``inp``/``m_splits`` go through Fp8Padding before the grouped
    GEMM. The output then goes through Fp8Unpadding with the caller's original
    ``m_splits``. Without FP8 the GroupedLinear is called directly.
    """

    def __init__(
        self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8
    ) -> None:
        super().__init__()

        self.padding = Fp8Padding(num_gemms)
        self.linear_fn = GroupedLinear(
            num_gemms,
            in_features,
            out_features,
            bias=bias,
            params_dtype=params_dtype,
            parallel_mode=parallel_mode,
            device="cuda",
        )
        self.unpadding = Fp8Unpadding(num_gemms)

        self.fp8 = fp8

    def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor:
        if not self.fp8:
            return self.linear_fn(inp, m_splits)

        padded_inp, padded_splits = self.padding(inp, m_splits)
        padded_out = self.linear_fn(padded_inp, padded_splits)
        return self.unpadding(padded_out, m_splits)


422
423
424
425
426
427
428
429
430
# Activation name -> torch reference module. Gated variants map to the
# element-wise activation that TorchGLU applies to the gate half.
_supported_act = dict(
    geglu=nn.GELU(approximate="tanh"),
    gelu=nn.GELU(approximate="tanh"),
    reglu=nn.ReLU(),
    relu=nn.ReLU(),
    swiglu=nn.SiLU(),
    qgelu=TorchQuickGELU(),
    srelu=TorchSquaredRELU(),
)
431

432

433
434
435
436
437
438
439
class TorchGLU(nn.Module):
    """Gated activation: ``act(first_half) * second_half`` along the last dim.

    The first half holds ``size // 2`` features and the second half the rest.
    """

    def __init__(self, activation: str):
        super().__init__()
        self.act = _supported_act[activation]

    def forward(self, x):
        half = x.size(-1) // 2
        gate, value = x[..., :half], x[..., half:]
        return self.act(gate) * value
444

445

446
class TorchLayerNormMLP(nn.Module):
    """Reference ``norm -> fc1 -> activation -> fc2`` block.

    For gated activations (names containing ``"glu"``), fc1 emits
    ``2 * ffn_hidden_size`` features and TorchGLU reduces them back to
    ``ffn_hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        eps: float = 1e-5,
        activation="gelu",
        normalization: str = "LayerNorm",
        bias: bool = True,
    ):
        super().__init__()
        if normalization == "LayerNorm":
            norm_cls = TorchLayerNorm
        elif normalization == "RMSNorm":
            norm_cls = TorchRMSNorm
        else:
            raise RuntimeError("Unsupported normalization")
        self.ln = norm_cls(hidden_size, eps=eps, zero_centered_gamma=False)

        gated = "glu" in activation
        self.gelu = TorchGLU(activation) if gated else _supported_act[activation]
        fc1_output_features = 2 * ffn_hidden_size if gated else ffn_hidden_size

        self.fc1 = nn.Linear(hidden_size, fc1_output_features, bias=bias)
        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)

    def forward(self, x):
        hidden = self.fc1(self.ln(x))
        return self.fc2(self.gelu(hidden))
476
477
478


class TorchGPT(nn.Module):
    """Reference transformer layer (LayerNorm + self-attention, then LayerNorm MLP).

    With ``parallel_attention_mlp`` both branches read the layer input ``x``
    and are summed into one residual. Otherwise the MLP reads the
    post-attention residual. Every residual branch uses dropout p=0.1.
    """

    def __init__(
        self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool
    ):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)
        self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
        self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
        self.parallel_attention_mlp = parallel_attention_mlp

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        attn_out = self.causal_attn(self.ln(x), attention_mask)
        drop_p, train = 0.1, self.training

        if not self.parallel_attention_mlp:
            x = x + nn.functional.dropout(attn_out, p=drop_p, training=train)
            return x + nn.functional.dropout(self.ln_mlp(x), p=drop_p, training=train)

        mlp_out = self.ln_mlp(x)
        return x + nn.functional.dropout(attn_out + mlp_out, p=drop_p, training=train)


505
506
507
def _test_e2e_selective_recompute(
    bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False
):
    """Run one forward/backward of a TransformerLayer, optionally recomputing core attention.

    Returns:
        ``[output, input_grad, *grads]`` where ``grads`` are the ``.grad`` of
        every parameter with ``requires_grad``.
    """
    # Identical starting RNG / FP8 state for the reference and recompute runs.
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.seq_len)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        te_out = block(
            hidden_states,
            attention_mask=causal_mask,
            checkpoint_core_attention=recompute,
        )
    te_out.sum().backward()
    torch.cuda.synchronize()

    param_grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [te_out, hidden_states.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
    """Core-attention recompute must reproduce the outputs and grads of a plain run."""
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
    run_args = (bs, dtype, config, fp8, recipe, fp8_model_params)
    outputs = _test_e2e_selective_recompute(*run_args, recompute=False)
    outputs_recompute = _test_e2e_selective_recompute(*run_args, recompute=True)

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols["atol"] = 1e-4
    if fp8 or fp8_model_params:
        tols.update(rtol=0.125, atol=0.0675)

    for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
        torch.testing.assert_close(test, ref, msg=f"Mismatch in tensor {i}", **tols)
596
597


598
def _test_e2e_full_recompute(
    bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False, use_reentrant=True
):
    """Run one forward/backward of a TransformerLayer, optionally via ``te_checkpoint``.

    The input only requires grad when ``use_reentrant`` is True, so its grad
    is reported only in that case.

    Returns:
        ``(outputs, names)``: parallel lists starting with ``"output"``, then
        ``"input"`` (reentrant only), then each parameter name whose
        ``requires_grad`` is set, paired with its ``.grad``.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=use_reentrant,
    )
    if use_reentrant:
        hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.seq_len)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if not recompute:
            te_out = block(
                hidden_states,
                attention_mask=causal_mask,
                checkpoint_core_attention=False,
            )
        else:
            te_out = te_checkpoint(
                block,
                hidden_states,
                attention_mask=causal_mask,
                checkpoint_core_attention=False,
                distribute_saved_activations=False,
                tp_group=None,
                use_reentrant=use_reentrant,
            )
    te_out.sum().backward()
    torch.cuda.synchronize()

    named = [("output", te_out)]
    if use_reentrant:
        named.append(("input", hidden_states.grad))
    named.extend((name, p.grad) for name, p in block.named_parameters() if p.requires_grad)

    outputs = [tensor for _, tensor in named]
    names = [label for label, _ in named]
    return outputs, names
668
669
670
671


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_reentrant", all_boolean)
def test_gpt_full_activation_recompute(
    dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
):
    """Full-layer recompute via te_checkpoint must match a plain forward/backward.

    Compares the output, the input grad (reentrant mode only), and every
    parameter grad. Mismatch messages include the tensor's name.
    """
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]

    fusion_env = "NVTE_BIAS_GELU_NVFUSION"
    saved_fusion = os.environ.get(fusion_env)
    if not use_reentrant:
        # Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion
        os.environ[fusion_env] = "0"
    try:
        outputs, names = _test_e2e_full_recompute(
            bs,
            dtype,
            config,
            fp8,
            recipe,
            fp8_model_params,
            recompute=False,
            use_reentrant=use_reentrant,
        )
        outputs_recompute, _ = _test_e2e_full_recompute(
            bs,
            dtype,
            config,
            fp8,
            recipe,
            fp8_model_params,
            recompute=True,
            use_reentrant=use_reentrant,
        )
    finally:
        if not use_reentrant:
            # Restore the previous bias+GELU fusion setting, even if a run
            # raised, to avoid contaminating other tests. Any value the user
            # set before the test is preserved rather than deleted.
            if saved_fusion is None:
                os.environ.pop(fusion_env, None)
            else:
                os.environ[fusion_env] = saved_fusion

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols["atol"] = 1e-3
    if fp8 or fp8_model_params:
        tols.update(dict(rtol=0.125, atol=0.0675))
    for i, (name, ref, test) in enumerate(zip(names, outputs, outputs_recompute)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i} ({name})",
            **tols,
        )
731
732
733
734
735
736


def _test_e2e_checkpointing_get_model(config, dtype):
    """Build the CUDA TransformerLayer used by the checkpoint save/load test."""
    sigma = 0.023
    return TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
        device="cuda",
    )


def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path="checkpoint.pt"):
    """Train a TransformerLayer for ``steps`` iterations, optionally via a mid-run reload.

    With ``checkpoint=True``, the state dict is saved to ``path`` after
    ``steps // 2`` iterations. The model is then rebuilt, the state dict
    reloaded, and the accumulated grads and RNG state restored before the
    remaining iterations run. ``path`` is removed at the end if it exists.

    Returns:
        ``[last_output, input_grad, *param_grads]``.
    """
    reset_rng_states()

    te_inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()

    block = _test_e2e_checkpointing_get_model(config, dtype)

    for _ in range(steps // 2):
        te_out = block(
            te_inp_hidden_states,
            None,
        )
        loss = te_out.sum()
        loss.backward()

    if checkpoint:
        # This process is necessary so that we can start afresh with
        # a new model while erasing all internal state to ensure that
        # loading from a checkpoint gives bitwise identical results.
        # Since gradients are being accumulated, it is important to
        # restore them post loading the checkpoint.
        torch.save(block.state_dict(), path)

        param_grads = [p.grad.clone() for p in block.parameters() if p.requires_grad]

        # Snapshot the RNG state locally and rewind to it after rebuilding the
        # model (presumably its construction consumes random numbers for
        # weight init), so the second half lines up with the uninterrupted
        # run. Do NOT overwrite the module-level _cpu_rng_state /
        # _cuda_rng_state: that would change the starting RNG state seen by
        # reset_rng_states() in every test that runs afterwards.
        cpu_rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state()

        del block
        block = _test_e2e_checkpointing_get_model(config, dtype)
        # weights_only=False is acceptable here: the file was written just
        # above by this process.
        block.load_state_dict(torch.load(path, weights_only=False))
        torch.set_rng_state(cpu_rng_state)
        torch.cuda.set_rng_state(cuda_rng_state)

        for p in block.parameters():
            if p.requires_grad:
                p.grad = param_grads.pop(0)

        assert not param_grads, "Oops!"

    for _ in range(steps // 2):
        te_out = block(
            te_inp_hidden_states,
            None,
        )
        loss = te_out.sum()
        loss.backward()

    torch.cuda.synchronize()

    if os.path.exists(path):
        os.remove(path)

    outputs = [te_out, te_inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
    """A save/reload halfway through training must reproduce the uninterrupted run."""
    config = model_configs[model]
    outputs, outputs_checkpoint = (
        _test_e2e_checkpointing(bs, dtype, config, checkpoint=use_ckpt)
        for use_ckpt in (False, True)
    )

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(rtol=2e-2, atol=2e-3)
    for i, pair in enumerate(zip(outputs, outputs_checkpoint)):
        ref, test = pair
        torch.testing.assert_close(test, ref, msg=f"Mismatch in tensor {i}", **tols)
843
844
845
846
847
848


def _test_e2e_gpt_accuracy(block, bs, dtype, config):
    """Run one causal-masked forward/backward through ``block`` on a random input.

    Returns:
        ``[output, input_grad, *param_grads]``.
    """
    reset_rng_states()

    hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.seq_len)

    out = block(hidden_states, attention_mask=causal_mask)
    out.sum().backward()
    torch.cuda.synchronize()

    grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [out, hidden_states.grad] + grads


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
    """Compare a TE ``TransformerLayer`` against the ``TorchGPT`` reference.

    Both models are put in eval mode. The reference receives clones of every
    TE parameter, so both start from identical weights. The output is then
    compared, followed by the input and parameter gradients (small model only).
    """
    config = model_configs[model]

    te_gpt = TransformerLayer(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        num_attention_heads=config.num_attention_heads,
        layernorm_epsilon=config.eps,
        attention_dropout=0.1,
        hidden_dropout=0.1,
        params_dtype=dtype,
        fuse_qkv_params=True,
        qkv_weight_interleaved=False,
        parallel_attention_mlp=parallel_attention_mlp,
        device="cuda",
    ).eval()

    torch_gpt = (
        TorchGPT(
            config.hidden_size,
            config.eps,
            config.num_attention_heads,
            parallel_attention_mlp=parallel_attention_mlp,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params: copy each TE weight/bias into the matching reference param.
    with torch.no_grad():
        torch_gpt.ln.weight = Parameter(
            te_gpt.self_attention.layernorm_qkv.layer_norm_weight.clone()
        )
        torch_gpt.ln.bias = Parameter(te_gpt.self_attention.layernorm_qkv.layer_norm_bias.clone())
        torch_gpt.causal_attn.mhsa.in_proj_weight = Parameter(
            te_gpt.self_attention.layernorm_qkv.weight.clone()
        )
        torch_gpt.causal_attn.mhsa.in_proj_bias = Parameter(
            te_gpt.self_attention.layernorm_qkv.bias.clone()
        )
        torch_gpt.causal_attn.mhsa.out_proj.weight = Parameter(
            te_gpt.self_attention.proj.weight.clone()
        )
        torch_gpt.causal_attn.mhsa.out_proj.bias = Parameter(
            te_gpt.self_attention.proj.bias.clone()
        )
        torch_gpt.ln_mlp.ln.weight = Parameter(te_gpt.layernorm_mlp.layer_norm_weight.clone())
        torch_gpt.ln_mlp.ln.bias = Parameter(te_gpt.layernorm_mlp.layer_norm_bias.clone())
        torch_gpt.ln_mlp.fc1.weight = Parameter(te_gpt.layernorm_mlp.fc1_weight.clone())
        torch_gpt.ln_mlp.fc1.bias = Parameter(te_gpt.layernorm_mlp.fc1_bias.clone())
        torch_gpt.ln_mlp.fc2.weight = Parameter(te_gpt.layernorm_mlp.fc2_weight.clone())
        torch_gpt.ln_mlp.fc2.bias = Parameter(te_gpt.layernorm_mlp.fc2_bias.clone())

    te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config)
    torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config)

    atol = {
        torch.float32: 5e-3,
        torch.half: 5e-2,
        torch.bfloat16: 1e-1,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

    # Check gradients, only for small model
    if model == "small":
        atol[torch.float32] = 5e-2
        rtol = {
            torch.float32: 1e-2,
            torch.half: 1e-2,
            torch.bfloat16: 1e-2,
        }
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
    """Run one forward/backward pass of an attention ``block``.

    Args:
        block: attention module under test.
        bs: batch size.
        dtype: dtype of the random ``(seq_len, bs, hidden_size)`` input.
        config: model config providing ``seq_len`` and ``hidden_size``.
        mask_type: ``"causal"`` builds a causal mask; any other value passes
            ``attention_mask=None``.
        te: when True, also pass ``attn_mask_type=mask_type`` (presumably a
            keyword only the Transformer Engine module accepts).

    Returns:
        ``[output, input_grad, *param_grads]`` for parameters with ``requires_grad``.
    """
    reset_rng_states()

    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
    inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None

    forward_kwargs = {}
    if te:
        forward_kwargs["attn_mask_type"] = mask_type
    forward_kwargs["attention_mask"] = inp_attn_mask

    out = block(inp_hidden_states, **forward_kwargs)
    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    outputs.extend(p.grad for p in block.parameters() if p.requires_grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
    """Compare TE ``MultiheadAttention`` (no input layernorm) against ``TorchMHA``.

    The reference receives clones of the TE QKV and output-projection
    parameters. The output and then all gradients are compared.
    """
    config = model_configs[model]

    te_mha = MultiheadAttention(
        config.hidden_size,
        config.num_attention_heads,
        fuse_qkv_params=True,
        params_dtype=dtype,
        qkv_weight_interleaved=False,
        input_layernorm=False,
        device="cuda",
    ).eval()

    torch_mha = (
        TorchMHA(
            config.hidden_size,
            config.num_attention_heads,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_mha.mhsa.in_proj_weight = Parameter(te_mha.qkv.weight.clone())
        torch_mha.mhsa.in_proj_bias = Parameter(te_mha.qkv.bias.clone())
        torch_mha.mhsa.out_proj.weight = Parameter(te_mha.proj.weight.clone())
        torch_mha.mhsa.out_proj.bias = Parameter(te_mha.proj.bias.clone())

    te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type, te=True)
    torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type, te=False)

    # Check output.
    if dtype == torch.float32:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
    else:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

    # Check gradients, only for small model
    if model == "small":
        atol = {
            torch.float32: 5e-2,
            torch.half: 5e-2,
            torch.bfloat16: 5e-2,
        }
        rtol = {
            torch.float32: 1e-2,
            torch.half: 1e-2,
            torch.bfloat16: 1e-2,
        }
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


def _test_granular_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of a single-input ``block``.

    If the block returns a list or tuple, only its first element is treated
    as the output.

    Args:
        block: module called as ``block(hidden_states)``.
        bs: batch size.
        dtype: dtype of the random ``(seq_len, bs, hidden_size)`` input.
        config: model config providing ``seq_len`` and ``hidden_size``.

    Returns:
        ``[output, input_grad, *param_grads]`` for parameters with ``requires_grad``.
    """
    reset_rng_states()

    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    out = block(inp_hidden_states)
    if isinstance(out, (list, tuple)):
        out = out[0]
    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    outputs.extend(p.grad for p in block.parameters() if p.requires_grad)
    return outputs


def _test_dpa_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of a dot-product-attention ``block``.

    Random query/key/value tensors of shape
    ``(seq_len, bs, num_attention_heads, embed)`` are attended with a boolean
    mask that is True strictly above the diagonal.

    Returns:
        ``[output, query_grad, key_grad, value_grad]``.
    """
    reset_rng_states()

    # True above the diagonal; presumably True marks masked (future) positions.
    mask = torch.triu(
        torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1
    )
    query, key, value = [
        torch.randn(
            (config.seq_len, bs, config.num_attention_heads, config.embed),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        for _ in range(3)
    ]

    query.retain_grad()
    key.retain_grad()
    value.retain_grad()

    out = block(query, key, value, attention_mask=mask)
    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()

    return [out, query.grad, key.grad, value.grad]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_dpa_accuracy(dtype, bs, model):
    """Compare TE ``DotProductAttention`` against ``TorchDotProductAttention``.

    Dropout is disabled in both modules so that the comparison is deterministic.
    """
    config = model_configs[model]

    te_dpa = (
        DotProductAttention(
            config.num_attention_heads,
            config.embed,
            attention_dropout=0.0,  # disable dropout, FU uses rng differently
        )
        .to(dtype=dtype)
        .cuda()
    )

    torch_dpa = (
        TorchDotProductAttention(
            config.embed,
            0.0,  # dropout
        )
        .to(dtype=dtype)
        .cuda()
    )

    te_outputs = _test_dpa_accuracy(te_dpa, bs, dtype, config)
    torch_outputs = _test_dpa_accuracy(torch_dpa, bs, dtype, config)

    # Check output.
    if dtype == torch.float32:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
    else:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

    # Check query/key/value gradients.
    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)


class TestReturnBiasModule(nn.Module):
    """Wrap a module built with ``return_bias``/``bias`` kwargs so it returns one tensor.

    When ``return_bias`` is set, the wrapped module is expected to return an
    ``(out, bias)`` pair. The bias is added back onto ``out`` only when
    ``bias`` is also set.

    Args:
        mod: module class, instantiated as ``mod(**kwargs)``.
        **kwargs: constructor kwargs; must include ``return_bias`` and ``bias``.
    """

    # Helper, not a test class: stop pytest from trying to collect it
    # (it has an ``__init__``, which triggers a collection warning).
    __test__ = False

    def __init__(self, mod, **kwargs):
        super().__init__()
        self.te_module = mod(**kwargs)
        self.return_bias = kwargs["return_bias"]
        self.bias = kwargs["bias"]

    def forward(self, x):
        if self.return_bias:
            out, bias = self.te_module(x)
            if self.bias:
                out = out + bias
            return out
        return self.te_module(x)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_linear_accuracy(dtype, bs, model, return_bias, bias):
    """Compare TE ``Linear`` against ``torch.nn.Linear`` with shared parameters.

    The TE module is wrapped in ``TestReturnBiasModule``, so a returned bias is
    added back before comparison.
    """
    config = model_configs[model]

    te_linear = TestReturnBiasModule(
        Linear,
        in_features=config.hidden_size,
        out_features=4 * config.hidden_size,
        params_dtype=dtype,
        return_bias=return_bias,
        bias=bias,
        device="cuda",
    )

    torch_linear = torch.nn.Linear(
        config.hidden_size,
        4 * config.hidden_size,
        bias=bias,
        device="cuda",
        dtype=dtype,
    )

    # Share params
    with torch.no_grad():
        torch_linear.weight = Parameter(te_linear.te_module.weight.clone())
        if bias:
            torch_linear.bias = Parameter(te_linear.te_module.bias.clone())

    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)

    # Check output and all gradients with one shared tolerance set.
    if model == "small":
        tolerance = 5e-3 if dtype == torch.float32 else 5e-2
        rtol = {
            torch.float32: 1.3e-6,
            torch.half: 1e-2,
            torch.bfloat16: 2e-2,
        }
        for te_output, torch_output in zip(te_outputs, torch_outputs):
            assert_allclose(te_output, torch_output, tolerance, rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
    """Compare TE ``RMSNorm`` against ``TorchRMSNorm`` with a shared weight.

    The output is checked first, then the input/weight gradients using a
    looser float32 atol.
    """
    config = model_configs[model]

    te_rmsnorm = RMSNorm(
        config.hidden_size,
        eps=eps,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_rmsnorm = (
        TorchRMSNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_rmsnorm.weight = Parameter(te_rmsnorm.weight.clone())

    te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)

    atol = {
        torch.float32: 1e-7,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

    atol[torch.float32] = 2e-3
    rtol = {
        torch.float32: 1.3e-6,
        torch.half: 1e-3,
        torch.bfloat16: 1.6e-2,
    }
    # Check gradients
    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
    """Compare TE ``LayerNorm`` against ``TorchLayerNorm`` with shared weight and bias.

    The output is checked first, then the gradients using a looser float32 atol.
    """
    config = model_configs[model]

    te_layernorm = LayerNorm(
        config.hidden_size,
        eps=eps,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_layernorm = (
        TorchLayerNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_layernorm.weight = Parameter(te_layernorm.weight.clone())
        torch_layernorm.bias = Parameter(te_layernorm.bias.clone())

    te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)

    atol = {
        torch.float32: 1e-7,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

    rtol = {
        torch.float32: 1.3e-6,
        torch.half: 1e-3,
        torch.bfloat16: 1.6e-2,
    }
    atol[torch.float32] = 1e-4
    # Check gradients
    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_linear_accuracy(
    dtype, bs, model, normalization, zero_centered_gamma, return_bias, bias
):
    """Compare TE ``LayerNormLinear`` against ``TorchLayerNormLinear``.

    Norm weight (and norm bias, except for RMSNorm) plus linear weight/bias
    are cloned from the TE module into the reference before both run.
    """
    config = model_configs[model]

    te_ln_linear = TestReturnBiasModule(
        LayerNormLinear,
        in_features=config.hidden_size,
        out_features=4 * config.hidden_size,
        eps=config.eps,
        normalization=normalization,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        return_bias=return_bias,
        bias=bias,
        device="cuda",
    )

    torch_ln_linear = (
        TorchLayerNormLinear(
            config.hidden_size,
            4 * config.hidden_size,
            config.eps,
            normalization=normalization,
            zero_centered_gamma=zero_centered_gamma,
            bias=bias,
        )
        .to(dtype=dtype)
        .cuda()
    )

    # Share params
    with torch.no_grad():
        torch_ln_linear.layernorm.weight = Parameter(
            te_ln_linear.te_module.layer_norm_weight.clone()
        )
        if normalization != "RMSNorm":
            torch_ln_linear.layernorm.bias = Parameter(
                te_ln_linear.te_module.layer_norm_bias.clone()
            )
        torch_ln_linear.linear.weight = Parameter(te_ln_linear.te_module.weight.clone())
        if bias:
            torch_ln_linear.linear.bias = Parameter(te_ln_linear.te_module.bias.clone())

    te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)

    atol = {
        torch.float32: 2.5e-4,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }
    rtol = {
        torch.float32: 1e-3,
        torch.half: 4e-2,
        torch.bfloat16: 4e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype])

    if model == "small":
        atol = {
            torch.float32: 1e-3,
            torch.half: 5e-2,
            torch.bfloat16: 5e-2,
        }
        rtol = {
            torch.float32: 1e-3,
            torch.half: 4e-2,
            torch.bfloat16: 4e-2,
        }
        # Check gradients
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):
    """Compare TE ``LayerNormMLP`` against ``TorchLayerNormMLP``.

    Norm weight/bias (bias skipped for RMSNorm) and both FC layers' weights
    (and biases when ``bias``) are cloned from the TE module into the reference.
    """
    config = model_configs[model]

    te_ln_mlp = TestReturnBiasModule(
        LayerNormMLP,
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        return_bias=return_bias,
        bias=bias,
        device="cuda",
    )

    torch_ln_mlp = (
        TorchLayerNormMLP(
            config.hidden_size,
            4 * config.hidden_size,
            activation=activation,
            normalization=normalization,
            bias=bias,
        )
        .to(dtype=dtype)
        .cuda()
    )

    # Share params
    with torch.no_grad():
        torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.te_module.layer_norm_weight.clone())
        if normalization != "RMSNorm":
            torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.te_module.layer_norm_bias.clone())
        torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.te_module.fc1_weight.clone())
        torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.te_module.fc2_weight.clone())
        if bias:
            torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.te_module.fc1_bias.clone())
            torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.te_module.fc2_bias.clone())

    te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)

    atol = {
        torch.float32: 2e-2,
        torch.half: 5e-2,
        torch.bfloat16: 5e-2,
    }

    rtol = {
        torch.float32: 1e-3,
        torch.half: 4e-2,
        torch.bfloat16: 4e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype])

    # Check gradients, only for small model
    rtol = {
        torch.float32: 1e-3,
        torch.half: 1e-2,
        torch.bfloat16: 4e-2,
    }
    atol[torch.half] = 2e-1
    atol[torch.bfloat16] = 2e-1
    if model == "small":
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


def _test_grouped_linear_accuracy(
    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
):
    """Run one forward/backward pass through a grouped or sequential linear.

    The ``seq_len`` dimension of a random ``(seq_len, bs, hidden_size)`` input
    is divided into ``num_gemms`` chunks. When ``num_gemms > 1``, one chunk is
    forced to be empty. A ``GroupedLinear`` receives the whole input plus the
    chunk sizes scaled by ``bs``. Any other ``block`` is indexed as
    ``block[i]``; each module is applied to its own ``torch.split`` chunk and
    the results are concatenated.

    Args:
        block: ``GroupedLinear`` or an indexable container of ``num_gemms`` modules.
        num_gemms: number of GEMMs / chunks.
        bs: batch size.
        dtype: input dtype.
        config: model config providing ``seq_len`` and ``hidden_size``.
        recipe: recipe passed to ``fp8_autocast``; when ``fp8`` is set it
            also picks the chunk granularity (16 for delayed, 128 for MXFP8).
        fp8: whether ``fp8_autocast`` is enabled.
        fuse_wgrad_accumulation: not read here. Fused accumulation is detected
            per parameter through a non-None ``main_grad`` attribute.

    Returns:
        ``[output, input_grad, *param_grads]``. For a parameter carrying a
        non-None ``main_grad``, that tensor is returned instead of ``.grad``,
        and ``.grad`` is asserted to be None.
    """
    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    if num_gemms > 1:
        # Chunk sizes are multiples of split_size (presumably to satisfy the
        # FP8 recipe's alignment requirement on the GEMM M dimension).
        split_size = 1
        if fp8:
            if recipe.delayed():
                split_size = 16
            if recipe.mxfp8():
                split_size = 128
        m = config.seq_len // split_size
        dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
        # Duplicate the last break point so one chunk is empty. With
        # num_gemms == 2 there are no break points yet, so use 0
        # (dist[-1] would raise IndexError).
        dist.append(dist[-1] if dist else 0)
        m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
        m_splits = m_splits * split_size
        assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
    else:
        m_splits = torch.tensor([config.seq_len])

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, GroupedLinear):
            # Presumably GroupedLinear splits the flattened (seq_len * bs) rows.
            m_splits = m_splits * bs
            out = block(inp_hidden_states, m_splits.tolist())
        else:
            out = torch.cat(
                [
                    block[i](inp)
                    for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                ]
            )
    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            if getattr(p, "main_grad", None) is not None:
                outputs.append(p.main_grad)
                assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
            else:
                outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_grouped_linear_accuracy(
    dtype,
    num_gemms,
    bs,
    model,
    fp8,
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
    parallel_mode=None,
):
    """Check that ``GroupedLinear`` bit-matches ``num_gemms`` separate ``Linear`` modules.

    Each ``Linear`` receives clones of the matching grouped weight/bias. When
    ``fuse_wgrad_accumulation`` is set, the same random float32 ``main_grad``
    is attached to both weights. Every output/gradient must match exactly.
    """
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
        pytest.skip("MXFP8 unsupported for grouped linear.")
    if fp8 and recipe.float8_current_scaling():
        pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
    if recipe.float8_block_scaling():
        pytest.skip("Grouped linear for FP8 blockwise unsupported.")

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=True,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
        ).eval()
        sequential_linear = torch.nn.ModuleList(
            [
                Linear(
                    config.hidden_size,
                    4 * config.hidden_size,
                    bias=True,
                    params_dtype=dtype,
                    parallel_mode=parallel_mode,
                    device="cuda",
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                ).eval()
                for _ in range(num_gemms)
            ]
        )

    # Share params
    with torch.no_grad():
        for i in range(num_gemms):
            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
            sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
            if fuse_wgrad_accumulation:
                weight_i = getattr(grouped_linear, f"weight{i}")
                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
                sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()

    outputs_ref = _test_grouped_linear_accuracy(
        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
    )
    outputs = _test_grouped_linear_accuracy(
        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
    )

    # Should be bit-wise match
    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
        # Prefix torch's detailed mismatch report with the tensor index.
        torch.testing.assert_close(
            o, o_ref, rtol=0, atol=0, msg=lambda m, i=i: f"Mismatch in tensor {i}: {m}"
        )


@pytest.mark.parametrize("parallel_mode", ["column", "row"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe):
    """Split the tests to save CI time"""
    # Runs one fixed configuration (fp32, 6 GEMMs, FP8 with FP8 params and
    # fused wgrad accumulation) for each parallel mode instead of the full grid.
    test_grouped_linear_accuracy(
        dtype=torch.float32,
        num_gemms=6,
        bs=2,
        model="126m",
        fp8=True,
        recipe=recipe,
        fp8_model_params=True,
        parallel_mode=parallel_mode,
        fuse_wgrad_accumulation=True,
    )


@pytest.mark.parametrize("recipe", fp8_recipes)
def test_grouped_linear_accuracy_single_gemm(recipe):
    """Split the tests to save CI time"""
    # Covers the num_gemms == 1 path (a single split of the whole sequence)
    # with one fixed FP8 configuration per recipe.
    test_grouped_linear_accuracy(
        dtype=torch.float32,
        num_gemms=1,
        bs=2,
        model="126m",
        fp8=True,
        recipe=recipe,
        fp8_model_params=True,
        fuse_wgrad_accumulation=True,
    )


def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
    """Run one forward/backward pass through a grouped linear with ragged splits.

    A random ``(seq_len * bs, hidden_size)`` input is split into ``num_gemms``
    random positive chunk sizes (reproducible via the module-level ``seed``).
    A ``TorchGroupedLinearWithPadding`` block receives the raw splits. Any
    other block, when ``fp8`` is set, receives an input whose chunks are
    zero-padded to multiples of 16; the padding rows are then stripped from
    its output.

    Returns:
        ``[output, input_grad, *param_grads]`` for parameters with ``requires_grad``.
    """

    def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
        """Zero-pad each chunk of a 2-D tensor up to a multiple of 16 rows.

        Returns the padded tensor and the padded chunk sizes.
        """
        padded_tokens_per_expert = [
            (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert
        ]
        hidden_states = torch.split(hidden_states, tokens_per_expert)
        padded_hidden_states = []
        for hidden_state, actual_num_tokens, padded_num_tokens in zip(
            hidden_states, tokens_per_expert, padded_tokens_per_expert
        ):
            padded_hidden_states.append(hidden_state)
            if padded_num_tokens > actual_num_tokens:
                pad_tensor = torch.zeros(
                    padded_num_tokens - actual_num_tokens,
                    hidden_state.shape[1],
                    dtype=hidden_state.dtype,
                    device=hidden_state.device,
                )
                padded_hidden_states.append(pad_tensor)
        padded_hidden_states = torch.cat(padded_hidden_states, dim=0)
        return padded_hidden_states, padded_tokens_per_expert

    def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert):
        """Inverse of ``_pad_tensor_for_fp8``: keep the first ``actual`` rows of each chunk."""
        inputmats = torch.split(
            padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert
        )
        hidden_states = torch.cat(
            [mat[: actual_tokens_per_expert[i]] for i, mat in enumerate(inputmats)],
            dim=0,
        )

        return hidden_states

    def _generate_random_numbers(n, total_sum):
        """Return ``n`` distinct-breakpoint positive ints summing to ``total_sum``.

        Reseeds Python's global ``random`` module with ``seed`` first, so the
        result is reproducible. ``n == 1`` yields ``[total_sum]``.
        """
        if n <= 0:
            return []

        # reset seed
        random.seed(seed)

        breaks = sorted(random.sample(range(1, total_sum), n - 1))
        # Chunk sizes are the gaps between consecutive boundaries. This also
        # covers n == 1, where the old form indexed an empty `breaks`.
        boundaries = [0] + breaks + [total_sum]
        return [hi - lo for lo, hi in zip(boundaries, boundaries[1:])]

    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.seq_len * bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, TorchGroupedLinearWithPadding):
            out = block(inp_hidden_states, m_splits)
        elif fp8:
            padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
                inp_hidden_states, m_splits
            )
            padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits)
            out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits)
        else:
            out = block(inp_hidden_states, m_splits)

    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    outputs.extend(p.grad for p in block.parameters() if p.requires_grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy(
    dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None
):
    """Compare ``TorchGroupedLinearWithPadding`` against a TE ``GroupedLinear`` reference.

    Both modules are built under the same ``fp8_model_init`` context. The
    reference receives clones of the per-GEMM weights of the padded module's
    inner ``linear_fn``. ``_test_padding_grouped_linear_accuracy`` then runs a
    forward/backward pass on each. The output, the input gradient and every
    trainable parameter gradient must match bit-for-bit.
    """
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
        pytest.skip("MXFP8 unsupported for grouped linear.")
    if fp8 and recipe.float8_current_scaling():
        pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
    if recipe.float8_block_scaling():
        pytest.skip("Float8 block scaling unsupported for grouped linear.")

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = TorchGroupedLinearWithPadding(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            fp8=fp8,
        ).eval()

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        ref_grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
        ).eval()

    # Share params: copy every per-GEMM weight of the padded module's inner
    # grouped linear into the reference so both start from identical weights.
    with torch.no_grad():
        inner_grouped_linear = grouped_linear.linear_fn
        for i in range(num_gemms):
            setattr(
                ref_grouped_linear,
                f"weight{i}",
                Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
            )

    outputs = _test_padding_grouped_linear_accuracy(
        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
    )
    outputs_ref = _test_padding_grouped_linear_accuracy(
        ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
    )

    # Should be a bit-wise match; report which tensor differs on failure.
    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0, msg=f"Mismatch in tensor {i}")


def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
    """Train ``block`` for warmup steps plus one step on new data, optionally via a CUDA graph.

    Args:
        block: module to train; an SGD optimizer (lr=0.1) is built over its parameters.
        bs: batch size of the random inputs/targets.
        dtype: dtype of the random inputs/targets.
        config: model config providing ``seq_len`` and ``hidden_size``.
        graph: if True, capture ``train_step`` into a ``torch.cuda.CUDAGraph``
            and replay it for the final step; otherwise call ``train_step`` eagerly.

    Returns:
        ``(output, grads)``: a clone of the final step's output, and a list
        holding the input gradient followed by the gradients of all trainable
        parameters of ``block``.
    """
    reset_rng_states()

    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for graph capture. A captured graph is tied to these
    # exact tensors, so new data is copied into them in place further below.
    static_input = torch.randn(
        config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
    )
    static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype)

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    # Basic training loop.
    def train_step():
        # set_to_none=False zeroes gradients in place rather than dropping them,
        # presumably so gradient buffers keep their storage across graph capture
        # and replay. static_input is not owned by the optimizer, so its .grad
        # is not reset by this call.
        optimizer.zero_grad(set_to_none=False)
        out = block(static_input)
        loss = loss_fn(out, static_target)
        loss.backward()
        optimizer.step()
        return out

    # Warmup steps in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            train_step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture graph.
    g = None
    static_output = None
    if graph:
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_output = train_step()

    # Run with new data.
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    if graph:
        g.replay()
    else:
        static_output = train_step()

    grads = [static_input.grad]
    for p in block.parameters():
        if p.requires_grad:
            grads.append(p.grad)

    with torch.no_grad():
        output = static_output.clone()
    return output, grads


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_cuda_graph(dtype, bs, model):
    """Check that CUDA-graphed training of a ``TransformerLayer`` matches eager training.

    Two layers are built with the same arguments, and the second receives a
    copy of the first's parameters. One is trained eagerly and the other
    through a captured CUDA graph. Outputs, updated parameters and gradients
    are then compared with ``assert_allclose(..., 1e-3)``.
    """
    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block_args = (
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
    )
    block_kwargs = dict(
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        device="cuda",
    )
    block = TransformerLayer(*block_args, **block_kwargs)
    graphed_block = TransformerLayer(*block_args, **block_kwargs)
    # Make both layers start from identical weights.
    with torch.no_grad():
        for param1, param2 in zip(block.parameters(), graphed_block.parameters()):
            param2.copy_(param1)

    out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
    graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
    params = list(block.parameters())
    graphed_params = list(graphed_block.parameters())

    # Check that results match
    assert_allclose(out, graphed_out, 1e-3)
    assert_allclose(params, graphed_params, 1e-3)
    assert_allclose(grads, graphed_grads, 1e-3)


def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
    """Run one FP8 forward/backward pass through a freshly built ``TransformerLayer``.

    Args:
        bs: batch size.
        dtype: parameter and input dtype.
        config: model config (hidden size, heads, ``seq_len``, ``eps``, ...).
        fp8_model_params: passed as ``enabled`` to ``fp8_model_init`` while
            building the layer.
        recipe: FP8 recipe used both for model init and for ``fp8_autocast``.

    Returns:
        List containing the layer output, the input gradient and the gradients
        of all trainable parameters.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    with fp8_model_init(enabled=fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    te_inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)

    # The forward always runs under FP8 autocast; only parameter storage varies.
    with fp8_autocast(enabled=True, fp8_recipe=recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    outputs = [te_out, te_inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
    """Check that FP8 model parameters give results close to regular parameters.

    The same ``TransformerLayer`` forward/backward is run once with
    ``fp8_model_params=False`` and once with ``True``. Outputs and gradients
    are compared within the loose tolerances in ``tols``.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]

    outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe)
    outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True, recipe)

    # Check that results match
    tols = dict(rtol=0.125, atol=0.0675)
    for i, (ref, test) in enumerate(zip(outputs, outputs_fp8_params)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )

1967
1968
1969

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_transformer_layer_hidden_states_format(dtype, bs, model):
    """Check that ``TransformerLayer`` gives the same result for sbhd, bshd and thd inputs.

    Three layers differing only in ``attn_input_format`` (plus a padding-causal
    mask for thd) are built from the same seed, so their weights are identical.
    The same data is then fed to each in its own layout. The thd comparison
    runs only for non-float32 dtypes on Ampere or newer GPUs.
    """
    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    def _make_layer(attn_input_format, **extra_kwargs):
        # Re-seed with `torch.manual_seed` so the weights are identical across
        # layers. `*dropout` values are 0 so the forward passes are identical too.
        torch.manual_seed(0)
        return TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0,
            attention_dropout=0,
            kv_channels=config.embed,
            params_dtype=dtype,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            device="cuda",
            attn_input_format=attn_input_format,
            **extra_kwargs,
        )

    block_sbhd = _make_layer("sbhd")
    block_bshd = _make_layer("bshd")
    block_thd = _make_layer("thd", self_attn_mask_type="padding_causal")

    for (n1, p1), (n2, p2), (n3, p3) in zip(
        block_bshd.named_parameters(), block_sbhd.named_parameters(), block_thd.named_parameters()
    ):
        assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical"

    x_sbhd = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    x_bshd = x_sbhd.transpose(0, 1).contiguous()
    x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous()
    # Cumulative sequence offsets for the packed thd input: every sequence has
    # length seq_len, giving [0, seq_len, 2 * seq_len, ..., bs * seq_len].
    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len

    # Re-seed before each forward to make sure forward is also identical
    # (just in case some module decides to act fancy).
    torch.manual_seed(0)
    y_sbhd = block_sbhd(x_sbhd)

    torch.manual_seed(0)
    y_bshd = block_bshd(x_bshd)

    # Check that results match
    torch.testing.assert_close(
        y_bshd,
        y_sbhd.transpose(0, 1).contiguous(),
    )

    # THD is not supported in float32 and on GPUs older than Ampere, skip the test here
    if dtype != torch.float32 and sm_80plus:
        torch.manual_seed(0)
        y_thd = block_thd(
            x_thd,
            cu_seqlens_q=x_thd_cumsum,
            cu_seqlens_kv=x_thd_cumsum,
            max_seqlen_q=config.seq_len,
            max_seqlen_kv=config.seq_len,
        )

        torch.testing.assert_close(
            y_bshd,
            y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
        )

2089
2090
2091
2092
2093
2094
2095
2096

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model_key", model_configs_inference.keys())
@pytest.mark.parametrize("use_RoPE", all_boolean)
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
@pytest.mark.parametrize("is_paged", [False, True])
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
    """Check that decoding one token at a time with a KV cache matches a full forward pass.

    A ``TransformerLayer`` or ``MultiheadAttention`` first runs on the whole
    sequence with a causal mask. The same input is then fed one position per
    step through ``InferenceParams`` (optionally paged). The step outputs are
    assembled and compared against the full output with dtype-dependent
    tolerances.
    """
    reset_rng_states()

    if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
        pytest.skip("FusedAttention and FlashAttention do not support FP32")
    if use_RoPE:
        pytest.skip("KV cache does not support starting positions for RoPE")

    # Enable exactly one attention backend through the environment.
    # NOTE(review): these variables are not restored afterwards, so they stay
    # set for any later test in the same process.
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    elif backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    elif backend == "UnfusedAttention":
        os.environ["NVTE_UNFUSED_ATTN"] = "1"

    config = model_configs_inference[model_key]

    S = config.seq_len
    B = bs
    H = config.num_attention_heads
    D = config.hidden_size
    head_size = config.embed
    layer_number = 1

    # Limits the max size of KV-cache
    B_max = B
    S_max = S

    if module == "TransformerLayer":
        model = TransformerLayer(
            hidden_size=D,
            ffn_hidden_size=4 * D,
            num_attention_heads=H,
            attn_input_format=input_format,
            self_attn_mask_type="causal",
            enc_dec_attn_mask_type="causal",
            layer_number=layer_number,
            attention_dropout=0.0,
            params_dtype=dtype,
            device="cuda",
        ).eval()
    else:
        model = (
            MultiheadAttention(
                hidden_size=D,
                num_attention_heads=H,
                qkv_format=input_format,
                layer_number=layer_number,
                attention_dropout=0.0,
                attn_mask_type="causal",
                params_dtype=dtype,
            )
            .cuda()
            .eval()
        )

    inference_params = InferenceParams(
        max_batch_size=B_max,
        max_sequence_length=S_max,
        num_heads_kv=H,
        head_dim_k=head_size,
        dtype=dtype,
        is_paged=is_paged,
        total_num_pages=B_max * S_max // 256,
        page_size=256,
    )

    rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")

    inp = torch.randn((S, B, D), dtype=dtype, device="cuda")
    if input_format == "bshd":
        inp = inp.transpose(0, 1).contiguous()

    incremental_output = torch.zeros_like(inp)

    # Generate output for the entire sequence
    full_output = model(hidden_states=inp, rotary_pos_emb=rotary_freqs if use_RoPE else None)

    # Incrementally generate outputs using the KV cache: every sequence in the
    # batch advances by one token per step.
    step_dict = OrderedDict((seq_id, 1) for seq_id in range(B))

    # Mask used for the single-token decode steps; the keyword name differs
    # between TransformerLayer and MultiheadAttention.
    mask_type = "padding"
    if module == "TransformerLayer":
        mask_kwargs = {"self_attn_mask_type": mask_type}
    else:
        mask_kwargs = {"attn_mask_type": mask_type}

    for i in range(S):
        inference_params.pre_step(step_dict)

        if input_format == "sbhd":
            incremental_input = inp[i].view(1, B, D)
        else:
            incremental_input = inp[:, i, :].view(B, 1, D)

        # One query token per sequence: cu_seqlens = [0, 1, 2, ..., B].
        seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda")
        cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
        cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
        cu_seqlens_kv = cu_seqlens_q.clone()

        line_output = model(
            hidden_states=incremental_input,
            inference_params=inference_params,
            rotary_pos_emb=rotary_freqs if use_RoPE else None,
            **mask_kwargs,
            max_seqlen_q=1,
            max_seqlen_kv=S,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )

        if input_format == "sbhd":
            incremental_output[i, :, :] = line_output.view(B, D)
        else:
            incremental_output[:, i, :] = line_output.view(B, D)

    if module == "TransformerLayer":
        atol = {
            torch.float32: 5e-3,
            torch.half: 5e-3,
            torch.bfloat16: 5e-2,
        }
    else:
        atol = {
            torch.float32: 1e-3,
            torch.half: 1e-3,
            torch.bfloat16: 1e-2,
        }

    # Check if the fully generated output matches the one generated incrementally
    assert_allclose(full_output, incremental_output, atol[dtype])


@pytest.mark.parametrize(
    "shape",
    [
        (1, 127, 128, 512),
        (8, 15, 128, 512),
        (8, 1027, 128, 512),
        (16, 10027, 128, 512),
    ],
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm(shape, dtype, layout, accumulate):
    """Check ``general_grouped_gemm`` against one ``general_gemm`` call per group.

    ``shape`` is ``(z, m, k, n)``: ``z`` groups whose row counts ``m_splits``
    are random and sum to ``m``. Per the operand comments below, "TN" computes
    the output from weight and input, "NN" the dgrad and "NT" the per-group
    wgrad. Results must match bit-for-bit, with and without ``accumulate``.
    """
    torch.manual_seed(0)
    z, m, k, n = shape

    # Random sorted split points in [0, m); repeated points yield empty groups.
    dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist()
    m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
    assert m_splits.sum() == m and len(m_splits) == z
    m_splits = m_splits.tolist()

    if layout == "TN":
        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
        B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
        out = [torch.randn(m, n, dtype=dtype, device="cuda")]  # output
        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
        grad = False
        single_output = True
    elif layout == "NN":
        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
        B = list(
            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
        )  # grad_output
        out = [torch.randn(m, k, dtype=dtype, device="cuda")]  # dgrad
        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
        grad = True
        single_output = True
    else:  # layout == "NT"
        A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
        B = list(
            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
        )  # grad_output
        out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # wgrad
        out_ref = [o.clone() for o in out]
        grad = True
        single_output = False

    # Reference: one GEMM per group, written into independent output buffers.
    for i in range(z):
        general_gemm(
            A[i],
            B[i],
            get_workspace(),
            dtype,
            grad=grad,
            accumulate=accumulate,
            layout=layout,
            out=out_ref[i],
        )
    if single_output:
        out_ref = [torch.cat(out_ref)]

    general_grouped_gemm(
        A,
        B,
        out,
        dtype,
        get_multi_stream_cublas_workspace(),
        m_splits=m_splits,
        grad=grad,
        accumulate=accumulate,
        layout=layout,
        single_output=single_output,
    )

    # should be bit-wise match
    for i, (o, o_ref) in enumerate(zip(out, out_ref)):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0, msg=f"Mismatch in output {i}")


@pytest.mark.parametrize(
    "shape",
    [
        (1, 128, 128, 512),
        (8, 1024, 128, 512),
        (16, 4096, 128, 512),
    ],
)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("accumulate", [False, True])
def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
    """Check FP8 ``general_grouped_gemm`` against one FP8 ``general_gemm`` per group.

    ``shape`` is ``(z, m, k, n)``, and ``m`` is split evenly into ``z`` groups.
    The bf16 operands are quantized with ``Float8Quantizer`` and both paths
    write bf16 outputs, which must match bit-for-bit.

    NOTE(review): ``fp8_dtype`` is not used; both quantizers are hard-coded to
    ``kFloat8E4M3``, so the E5M2 parametrization repeats the E4M3 case.
    Confirm whether this is intended.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)

    z, m, k, n = shape
    # Even split; every listed shape has m divisible by z.
    m_splits = [m // z] * z

    dtype = torch.bfloat16
    A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
    B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # input
    out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # output
    out_ref = [o.clone() for o in out]

    # fp8 should be robust enough to this fake scale
    scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze()
    amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda")

    # One quantizer per group, each with its own copy of scale/amax.
    a_quantizers = [
        Float8Quantizer(
            scale.clone(),
            amax.clone(),
            tex.DType.kFloat8E4M3,
        )
        for _ in range(z)
    ]
    b_quantizers = [
        Float8Quantizer(
            scale.clone(),
            amax.clone(),
            tex.DType.kFloat8E4M3,
        )
        for _ in range(z)
    ]

    A_fp8 = []
    B_fp8 = []

    for i in range(z):
        A_fp8.append(a_quantizers[i](A[i]))
        B_fp8.append(b_quantizers[i](B[i]))

    # baseline
    for i in range(z):
        general_gemm(
            A_fp8[i],
            B_fp8[i],
            get_workspace(),
            dtype,
            out=out_ref[i],
            accumulate=accumulate,
        )

    general_grouped_gemm(
        A_fp8,
        B_fp8,
        out,
        dtype,
        get_multi_stream_cublas_workspace(),
        m_splits=m_splits,
        accumulate=accumulate,
    )

    # should be bit-wise match
    for i, (o, o_ref) in enumerate(zip(out, out_ref)):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0, msg=f"Mismatch in output {i}")


def test_noncontiguous():
    """Check that modules give identical results for a non-contiguous input and its contiguous copy.

    For LayerNorm, RMSNorm and Linear, two modules with identical parameters
    are run: one on a transposed (non-contiguous) view, the other on its
    ``.contiguous()`` copy. The collected outputs and gradients are compared
    with ``assert_allclose(..., 1e-7)``.
    """

    def _create2modules(m, params):
        # Build two instances of m(*params); the second receives clones of the
        # first's parameter data so both start from identical weights.
        mod1 = m(*params)
        mod2 = m(*params)
        for p1, p2 in zip(mod1.parameters(), mod2.parameters()):
            p2.data = p1.data.clone()

        return mod1, mod2

    def _run_module(m, inp):
        # Forward + sum().backward(); collect the output, the input gradient
        # (only if one was populated) and all trainable parameter gradients.
        out = m(inp)
        out.sum().backward()
        ret = [out]
        if inp.grad is not None:
            ret.append(inp.grad)

        for p in m.parameters():
            if p.requires_grad:
                ret.append(p.grad)
        return ret

    # NOTE(review): `a` (from `.T`) and `b` (from `.contiguous()`) are non-leaf
    # tensors and retain_grad() is never called, so their `.grad` stays None.
    # The input-gradient branch in _run_module therefore never fires, and input
    # gradients are not actually compared below.
    a = torch.randn((128, 256), device="cuda", requires_grad=True)
    a = a.T
    assert not a.is_contiguous(), "The test is supposed to test noncontiguous input."

    b = a.contiguous()

    # LayerNorm
    ln1, ln2 = _create2modules(LayerNorm, [128])
    outT = _run_module(ln1, a)
    out = _run_module(ln2, b)

    assert_allclose(out, outT, 1e-7)

    # RMSNorm
    ln1, ln2 = _create2modules(RMSNorm, [128])
    outT = _run_module(ln1, a)
    out = _run_module(ln2, b)

    assert_allclose(out, outT, 1e-7)

    # GEMM
    g1, g2 = _create2modules(Linear, [128, 128])
    outT = _run_module(g1, a)
    out = _run_module(g2, b)

    assert_allclose(out, outT, 1e-7)