# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
import math
6
import os
7
from typing import Dict, List, Tuple, Optional
8
import pytest
9
import random
10
11
12
13

import torch
import torch.nn as nn
from torch.nn import Parameter
yuguo's avatar
yuguo committed
14
from torch.utils.cpp_extension import IS_HIP_EXTENSION
15

16
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
17
18
19
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
20
21
22
    attention_mask_func,
)
from transformer_engine.pytorch import (
23
24
    autocast,
    quantized_model_init,
25
26
27
28
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
29
    GroupedLinear,
30
31
32
33
    MultiheadAttention,
    RMSNorm,
    TransformerLayer,
    LayerNorm,
34
35
    Fp8Padding,
    Fp8Unpadding,
36
37
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
38
39
40
41
42
43
    MXFP8Quantizer,
    get_device_compute_capability,
    is_fp8_available,
    is_mxfp8_available,
    is_fp8_block_scaling_available,
    is_bf16_available,
44
)
45
from transformer_engine.pytorch import torch_version
46
from transformer_engine.pytorch import checkpoint as te_checkpoint
47
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
48
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
49
from transformer_engine.common import recipe
50
import transformer_engine_torch as tex
51
from utils import ModelConfig, reset_rng_states
52

53

54
# Only run FP8 tests on supported devices. With ``return_reason=True`` the
# availability queries return an ``(available, reason)`` pair.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
# Unpack this pair like the two queries above. Binding the whole tuple to
# ``fp8_block_scaling_available`` would make it a non-empty tuple, which is always
# truthy, so the Float8BlockScaling recipe would be enabled on every device.
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
    return_reason=True
)

sm_80plus = get_device_compute_capability() >= (8, 0)

seed = 1234
# Reset RNG states.
reset_rng_states()

# torch >= 2.7 exposes the dynamo recompilation cap as ``recompile_limit``;
# older versions use ``cache_size_limit``.
if torch_version() >= (2, 7, 0):
    torch._dynamo.config.recompile_limit = 16
else:
    torch._dynamo.config.cache_size_limit = 16


model_configs = {
    "small": ModelConfig(1, 128, 8, 16, num_layers=4),
    "126m": ModelConfig(1, 2048, 12, 64, num_layers=12),
}
model_configs_inference = {"126m": ModelConfig(1, 256, 12, 64, num_layers=12)}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]

# bf16 requires sm_80 or higher, so it is only parametrized where available.
param_types = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_available() else []
)

batch_sizes = [1, 2]

all_boolean = [True, False]

# Activation names under test.
all_activations = [
    "gelu",
    "geglu",
    "qgelu",
    "qgeglu",
    "relu",
    "reglu",
    "srelu",
    "sreglu",
    "silu",
    "swiglu",
]

all_normalizations = ["LayerNorm", "RMSNorm"]

mask_types = ["causal", "no_mask"]

NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))

if NVTE_TEST_NVINSPECT_ENABLED:
    # With debug=True every layer must still produce the same numerics. The debug
    # config is fed a dummy feature so that debug mode is not switched off, which
    # can happen when no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )

# Each recipe is listed only when its availability flag is set.
fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes += [recipe.Float8CurrentScaling(), recipe.DelayedScaling()]

# CUTLASS grouped GEMM is only exercised on Hopper (compute capability 9.0).
use_cutlass_grouped_gemm = (
    [False, True] if torch.cuda.get_device_capability() == (9, 0) else [False]
)
def get_causal_attn_mask(sq: int, device="cuda") -> torch.Tensor:
    """Return an ``[sq, sq]`` boolean mask that is True strictly above the diagonal.

    Args:
        sq: sequence length (number of rows and columns).
        device: device on which to build the mask. It defaults to ``"cuda"``,
            which matches the original behaviour. Passing it explicitly allows CPU use.
    """
    return torch.triu(torch.ones(sq, sq, device=device), diagonal=1).bool()
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.

    Returns:
        A new ``dict`` with ``rtol`` and ``atol`` keys on every call, so callers may
        modify it freely.

    Raises:
        ValueError: if ``dtype`` is not float32, float16 or bfloat16.
    """
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    raise ValueError(f"Unsupported dtype ({dtype})")


def assert_allclose(
    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float = None, rtol: float = None
) -> None:
    """Ensures two lists of tensors are element-wise close.

    For each pair ``(t1, t2)`` the tolerances start from ``dtype_tols(t2.dtype)``.
    ``atol`` and ``rtol`` override them when given. Closeness follows
    ``torch.allclose``: ``|t1 - t2| <= atol + rtol * |t2|``.

    Raises:
        AssertionError: if the lists differ in length or any pair is not close. The
            message points at the element with the largest violation. If no element
            violates the bound by a finite amount, the message says the mismatch
            involves non-finite values instead.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        tols = dtype_tols(t2.dtype)
        if rtol is not None:
            tols["rtol"] = rtol
        if atol is not None:
            tols["atol"] = atol
        if torch.allclose(t1, t2, **tols):
            continue

        diff = torch.abs(t1 - t2)
        tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
        exceed_mask = diff > tol
        if exceed_mask.any():
            indices = torch.nonzero(exceed_mask, as_tuple=True)
            exceeding = diff[exceed_mask]
            max_diff = exceeding.max()
            # First position (in row-major order) that attains the largest violation.
            max_idx = (exceeding == max_diff).nonzero(as_tuple=True)[0][0]
            max_location = [idx[max_idx].item() for idx in indices]
            msg = (
                f"Outputs not close enough in tensor at idx={i}. "
                f"Maximum difference at location {max_location} "
                f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
                f"(diff {max_diff.item()})."
            )
        else:
            # torch.allclose failed, yet no element satisfies ``diff > tol``. This
            # happens when NaN/Inf values make that comparison False, for example a
            # NaN in either tensor, or an infinite ``t2`` that makes ``tol`` infinite.
            # Without this branch ``msg`` would be unbound here.
            msg = (
                f"Outputs not close enough in tensor at idx={i}. "
                "No element exceeds the tolerance by a finite amount; "
                "the mismatch involves non-finite values (NaN or Inf)."
            )
        raise AssertionError(msg)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Clear FP8 global state after every test in this module.

    Autouse fixture: the test body runs at the ``yield``. The reset that follows
    keeps FP8 global bookkeeping from one test out of the next.
    """
    yield  # the test runs here
    FP8GlobalStateManager.reset()
class TorchScaledMaskedSoftmax(nn.Module):
    """Reference softmax over the last dim with optional scaling and masking.

    The computation is done in float32 and the result is cast back to the dtype of
    ``inp``. The mask, when given, is applied via ``attention_mask_func``.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        out_dtype = inp.dtype
        scores = inp.float() if scale is None else inp.float() * scale
        if mask is not None:
            scores = attention_mask_func(scores, mask)
        return torch.softmax(scores, dim=-1).to(out_dtype)


class TorchDotProductAttention(torch.nn.Module):
    """Reference (unfused) scaled dot-product attention written in plain PyTorch.

    Query, key and value are ``[seq, batch, heads, head_dim]`` tensors, and the result
    is ``[seq_q, batch, heads * head_dim]``. Scores are multiplied by
    ``1 / sqrt(kv_channels)`` and passed through ``TorchScaledMaskedSoftmax`` with
    ``attention_mask``. Dropout is then applied to the probabilities.
    """

    def __init__(
        self,
        kv_channels: int,
        attention_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.norm_factor = math.sqrt(kv_channels)
        self.scale_mask_softmax = TorchScaledMaskedSoftmax()
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # Preallocated result tensor: [b * np, sq, sk]. It is created on the inputs'
        # device rather than ``torch.cuda.current_device()``, so the reference also
        # works for inputs on a non-current GPU or on the CPU. Its contents are never
        # read because baddbmm is called with ``beta=0.0``.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
        attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        context_layer = context_layer.view(seqlen, batch_size, -1)

        return context_layer
class TorchLayerNorm(nn.Module):
    """Reference LayerNorm computed in float32 and cast back to the input dtype.

    With ``zero_centered_gamma`` the stored weight is an offset from one, so the
    effective scale is ``1 + weight`` and the weight starts at zeros. Otherwise the
    weight is the scale itself and starts at ones. In both cases the module starts
    with a unit scale and zero bias.
    """

    def __init__(self, in_features: int, eps: float, zero_centered_gamma: bool):
        super().__init__()
        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        # Unit effective scale at init: 1 + 0 when zero-centred, 1 otherwise.
        # (The previous choice was inverted and gave an effective scale of 2 or 0.)
        initial_value = torch.zeros(in_features) if zero_centered_gamma else torch.ones(in_features)
        # Assigning nn.Parameter attributes registers them with the module.
        self.weight = nn.Parameter(initial_value)
        self.bias = nn.Parameter(torch.zeros(in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight if not self.zero_centered_gamma else 1 + self.weight
        w = w.to(torch.float32)
        b = self.bias.to(torch.float32)
        inp = x.to(torch.float32)
        out = torch.nn.functional.layer_norm(
            inp, (self.in_features,), weight=w, bias=b, eps=self.eps
        )
        return out.to(x.dtype)
# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
    """Reference RMSNorm: ``x / sqrt(sum(x**2) / in_features + eps) * gamma``.

    The sum runs over the last dimension in float32, and the result is cast back to
    ``x.dtype``. With ``zero_centered_gamma`` the effective scale is ``1 + weight``,
    so the weight starts at zeros. Otherwise it starts at ones. Either way the scale
    starts at one.
    """

    def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
        super().__init__()

        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        # Unit effective scale at init: 1 + 0 when zero-centred, 1 otherwise.
        # (The previous choice was inverted and gave an effective scale of 2 or 0.)
        initial_value = torch.zeros(in_features) if zero_centered_gamma else torch.ones(in_features)
        # Assigning an nn.Parameter attribute registers it with the module.
        self.weight = nn.Parameter(initial_value)

    def forward(self, x):
        norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
        d_x = self.in_features

        rms_x2 = norm_x2 / d_x + self.eps
        r_rms_x = rms_x2 ** (-1.0 / 2)
        x_normed = x * r_rms_x

        w = self.weight.float()
        if self.zero_centered_gamma:
            w = 1 + w
        return (w * x_normed).to(x.dtype)
class TorchLayerNormLinear(nn.Module):
    """Reference normalization followed by ``nn.Linear``.

    ``normalization`` selects ``TorchLayerNorm`` ("LayerNorm") or ``TorchRMSNorm``
    ("RMSNorm"). Any other value raises ``RuntimeError``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float,
        normalization: str = "LayerNorm",
        zero_centered_gamma: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        if normalization not in ("LayerNorm", "RMSNorm"):
            raise RuntimeError("Unsupported normalization")
        norm_cls = TorchLayerNorm if normalization == "LayerNorm" else TorchRMSNorm
        self.layernorm = norm_cls(in_features, eps=eps, zero_centered_gamma=zero_centered_gamma)
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.layernorm(x)
        return self.linear(normed)


class TorchMHA(nn.Module):
    """Reference self-attention built on sequence-first ``nn.MultiheadAttention``.

    Uses dropout 0.1 with biases. ``forward`` returns only the attention output.
    """

    def __init__(self, hidden_size: int, num_attention_heads: int):
        super().__init__()
        self.mhsa = nn.MultiheadAttention(
            hidden_size, num_attention_heads, dropout=0.1, bias=True, batch_first=False
        )

    def forward(self, x, attention_mask=None):
        result = self.mhsa(x, x, x, attn_mask=attention_mask, need_weights=False)
        # nn.MultiheadAttention returns an (output, weights) tuple; keep the output.
        return result[0] if isinstance(result, tuple) else result
class TorchQuickGELU(nn.Module):
    """Quick GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(1.702 * input)
        return input * gate
class TorchSquaredRELU(nn.Module):
    """Squared ReLU: ``x * x`` where ``x > 0`` and zero elsewhere."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        positive = input > 0
        return positive * input * input
class TorchGroupedLinearWithPadding(nn.Module):
    """``GroupedLinear`` wrapped in ``Fp8Padding``/``Fp8Unpadding`` when ``fp8`` is set.

    When ``fp8`` is False the input and split sizes go straight to ``GroupedLinear``.
    """

    def __init__(
        self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8
    ) -> None:
        super().__init__()

        self.padding = Fp8Padding(num_gemms)
        self.linear_fn = GroupedLinear(
            num_gemms,
            in_features,
            out_features,
            bias=bias,
            params_dtype=params_dtype,
            parallel_mode=parallel_mode,
            device="cuda",
        )
        self.unpadding = Fp8Unpadding(num_gemms)

        self.fp8 = fp8

    def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor:
        if not self.fp8:
            return self.linear_fn(inp, m_splits)

        # Pad, run the grouped GEMM on the padded splits, then unpad using the
        # caller's original split sizes.
        padded_inp, padded_splits = self.padding(inp, m_splits)
        padded_out = self.linear_fn(padded_inp, padded_splits)
        return self.unpadding(padded_out, m_splits)
# Reference activation modules keyed by activation name. Each gated ("*glu") name maps
# to the activation applied to the gated half of the input. Every entry is its own
# module instance.
_supported_act = dict(
    gelu=nn.GELU(approximate="tanh"),
    geglu=nn.GELU(approximate="tanh"),
    qgelu=TorchQuickGELU(),
    qgeglu=TorchQuickGELU(),
    relu=nn.ReLU(),
    reglu=nn.ReLU(),
    srelu=TorchSquaredRELU(),
    sreglu=TorchSquaredRELU(),
    silu=nn.SiLU(),
    swiglu=nn.SiLU(),
)
class TorchGLU(nn.Module):
    """Gated linear unit over the last dimension.

    The activation is applied to the first half, and the result is multiplied by the
    second half.
    """

    def __init__(self, activation: str):
        super().__init__()
        self.act = _supported_act[activation]

    def forward(self, x):
        split_at = x.size(-1) // 2
        gate, value = x[..., :split_at], x[..., split_at:]
        return self.act(gate) * value
class TorchLayerNormMLP(nn.Module):
    """Reference norm -> fc1 -> activation -> fc2 MLP.

    For gated activations (names containing "glu") fc1 produces
    ``2 * ffn_hidden_size`` features, which ``TorchGLU`` halves. Otherwise fc1
    produces ``ffn_hidden_size`` features. Unknown ``normalization`` values raise
    ``RuntimeError``.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        eps: float = 1e-5,
        activation="gelu",
        normalization: str = "LayerNorm",
        bias: bool = True,
    ):
        super().__init__()
        if normalization == "LayerNorm":
            norm_cls = TorchLayerNorm
        elif normalization == "RMSNorm":
            norm_cls = TorchRMSNorm
        else:
            raise RuntimeError("Unsupported normalization")
        self.ln = norm_cls(hidden_size, eps=eps, zero_centered_gamma=False)

        is_gated = "glu" in activation
        # The attribute is named ``gelu`` whatever the activation actually is.
        self.gelu = TorchGLU(activation) if is_gated else _supported_act[activation]
        fc1_output_features = 2 * ffn_hidden_size if is_gated else ffn_hidden_size

        self.fc1 = nn.Linear(hidden_size, fc1_output_features, bias=bias)
        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)

    def forward(self, x):
        hidden = self.fc1(self.ln(x))
        return self.fc2(self.gelu(hidden))


class TorchGPT(nn.Module):
    """Reference pre-LN GPT block built from ``TorchMHA`` and ``TorchLayerNormMLP``.

    With ``parallel_attention_mlp`` the attention and MLP branches both read the block
    input and their sum is added back once. Otherwise they are applied one after the
    other, each with its own residual. Residual branches use dropout 0.1.
    """

    def __init__(
        self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool
    ):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)
        self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
        self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
        self.parallel_attention_mlp = parallel_attention_mlp

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        attn_out = self.causal_attn(self.ln(x), attention_mask)

        if not self.parallel_attention_mlp:
            x = x + nn.functional.dropout(attn_out, p=0.1, training=self.training)
            mlp_out = self.ln_mlp(x)
            return x + nn.functional.dropout(mlp_out, p=0.1, training=self.training)

        mlp_out = self.ln_mlp(x)
        return x + nn.functional.dropout(attn_out + mlp_out, p=0.1, training=self.training)
def _test_e2e_selective_recompute(
    bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False
):
    """Run one forward/backward pass of a ``TransformerLayer`` on a causal input.

    ``recompute`` is forwarded as ``checkpoint_core_attention``. FP8 parameter
    initialisation is used only when both ``fp8`` and ``fp8_model_params`` are set.

    Returns:
        ``[output, input_grad, *param_grads]``, where ``param_grads`` follows
        ``block.parameters()`` order and skips parameters without ``requires_grad``.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.max_seqlen_q)

    with autocast(enabled=fp8, recipe=recipe):
        out = block(
            hidden_states,
            attention_mask=causal_mask,
            checkpoint_core_attention=recompute,
        )
    out.sum().backward()
    torch.cuda.synchronize()

    param_grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [out, hidden_states.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
    """Recomputing core attention must not change outputs or gradients."""
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]
    run_args = (bs, dtype, config, fp8, recipe, fp8_model_params)
    baseline = _test_e2e_selective_recompute(*run_args, recompute=False)
    recomputed = _test_e2e_selective_recompute(*run_args, recompute=True)

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols["atol"] = 1e-4
    if fp8 or fp8_model_params:
        tols.update(dict(rtol=0.125, atol=0.0675))

    for idx, (expected, actual) in enumerate(zip(baseline, recomputed)):
        torch.testing.assert_close(actual, expected, msg=f"Mismatch in tensor {idx}", **tols)
def _test_e2e_full_recompute(
    bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False, use_reentrant=True
):
    """Run one forward/backward of a ``TransformerLayer``, optionally fully checkpointed.

    With ``recompute`` the whole layer runs under ``te_checkpoint``; otherwise it is
    called directly. The input requires grad only when ``use_reentrant`` is set, so
    the input gradient is collected only in that case.

    Returns:
        ``(outputs, names)``. ``outputs`` holds the layer output, the input gradient
        (reentrant only) and the gradients of trainable parameters. ``names`` holds a
        matching label for each entry.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=use_reentrant,
    )
    if use_reentrant:
        hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.max_seqlen_q)

    with autocast(enabled=fp8, recipe=recipe):
        if recompute:
            out = te_checkpoint(
                block,
                hidden_states,
                attention_mask=causal_mask,
                checkpoint_core_attention=False,
                distribute_saved_activations=False,
                tp_group=None,
                use_reentrant=use_reentrant,
            )
        else:
            out = block(hidden_states, attention_mask=causal_mask, checkpoint_core_attention=False)
    out.sum().backward()
    torch.cuda.synchronize()

    labelled = [("output", out)]
    if use_reentrant:
        labelled.append(("input", hidden_states.grad))
    labelled.extend((name, p.grad) for name, p in block.named_parameters() if p.requires_grad)
    names = [label for label, _ in labelled]
    outputs = [tensor for _, tensor in labelled]
    return outputs, names


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_reentrant", all_boolean)
def test_gpt_full_activation_recompute(
    dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
):
    """Full-layer activation checkpointing must not change outputs or gradients.

    For non-reentrant checkpointing, bias+GELU fusion is disabled through the
    ``NVTE_BIAS_GELU_NVFUSION`` environment variable while the two runs execute.
    The variable's previous value, or its absence, is restored afterwards, even if
    a run fails.
    """
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")

    config = model_configs[model]

    fusion_env_var = "NVTE_BIAS_GELU_NVFUSION"
    saved_fusion_flag = os.environ.get(fusion_env_var)
    if not use_reentrant:
        # Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion
        os.environ[fusion_env_var] = "0"
    try:
        outputs, _ = _test_e2e_full_recompute(
            bs,
            dtype,
            config,
            fp8,
            recipe,
            fp8_model_params,
            recompute=False,
            use_reentrant=use_reentrant,
        )
        outputs_recompute, _ = _test_e2e_full_recompute(
            bs,
            dtype,
            config,
            fp8,
            recipe,
            fp8_model_params,
            recompute=True,
            use_reentrant=use_reentrant,
        )
    finally:
        if not use_reentrant:
            # Reset the bias+GELU fusion flag so other tests are not contaminated.
            # Restore the previous value rather than deleting it, and do so even
            # when one of the runs above raised.
            if saved_fusion_flag is None:
                os.environ.pop(fusion_env_var, None)
            else:
                os.environ[fusion_env_var] = saved_fusion_flag

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols["atol"] = 1e-3
    if fp8 or fp8_model_params:
        tols.update(dict(rtol=0.125, atol=0.0675))
    for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )


def _test_e2e_checkpointing_get_model(config, dtype):
    """Build the CUDA ``TransformerLayer`` used by the checkpointing test.

    Weights use ``init_method_normal(0.023)``. Output layers use
    ``scaled_init_method_normal(0.023, config.num_layers)``.
    """
    sigma = 0.023
    layer_kwargs = dict(
        layernorm_epsilon=config.eps,
        init_method=init_method_normal(sigma),
        output_layer_init_method=scaled_init_method_normal(sigma, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
        device="cuda",
    )
    return TransformerLayer(
        config.hidden_size, 4 * config.hidden_size, config.num_heads, **layer_kwargs
    )


def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path="checkpoint.pt"):
    """Run ``steps`` forward/backward passes of a ``TransformerLayer``.

    With ``checkpoint`` the model is saved to ``path`` after ``steps // 2`` passes.
    It is then rebuilt from scratch, reloaded, and given back its accumulated
    gradients and the CPU/CUDA RNG states before the remaining passes. This lets the
    result be compared against an uninterrupted run.

    The checkpoint file is removed only by the run that wrote it. Removal happens right
    after it has been read, or if saving/reloading fails. A ``checkpoint=False`` run
    never touches ``path``.

    Returns:
        ``[output, input_grad, *param_grads]`` from the last pass.
    """
    reset_rng_states()

    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()

    block = _test_e2e_checkpointing_get_model(config, dtype)

    for _ in range(steps // 2):
        te_out = block(te_inp_hidden_states, None)
        loss = te_out.sum()
        loss.backward()

    if checkpoint:
        # This process is necessary so that we can start afresh with
        # a new model while erasing all internal state to ensure that
        # loading from a checkpoint gives bitwise identical results.
        # Since gradients are being accumulated, it is important to
        # restore them post loading the checkpoint.
        torch.save(block.state_dict(), path)
        try:
            param_grads = [p.grad.clone() for p in block.parameters() if p.requires_grad]

            _cpu_rng_state = torch.get_rng_state()
            _cuda_rng_state = torch.cuda.get_rng_state()

            del block
            block = _test_e2e_checkpointing_get_model(config, dtype)
            # weights_only=False is acceptable: the file was written just above by
            # this function and is not external input.
            block.load_state_dict(torch.load(path, weights_only=False))
        finally:
            if os.path.exists(path):
                os.remove(path)

        # Rebuilding the model consumed RNG; restore the states captured before it.
        torch.set_rng_state(_cpu_rng_state)
        torch.cuda.set_rng_state(_cuda_rng_state)

        for p in block.parameters():
            if p.requires_grad:
                p.grad = param_grads.pop(0)

        assert not param_grads, "Oops!"

    for _ in range(steps // 2):
        te_out = block(te_inp_hidden_states, None)
        loss = te_out.sum()
        loss.backward()

    torch.cuda.synchronize()

    outputs = [te_out, te_inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
    """Saving and reloading mid-run must reproduce an uninterrupted run."""
    config = model_configs[model]
    uninterrupted = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
    reloaded = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(rtol=2e-2, atol=2e-3)
    for idx, (expected, actual) in enumerate(zip(uninterrupted, reloaded)):
        torch.testing.assert_close(actual, expected, msg=f"Mismatch in tensor {idx}", **tols)


def _test_e2e_gpt_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of ``block`` on a random causal-masked input.

    Returns:
        ``[output, input_grad, *param_grads]``, where ``param_grads`` follows
        ``block.parameters()`` order and skips parameters without ``requires_grad``.
    """
    reset_rng_states()

    hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.max_seqlen_q)

    out = block(hidden_states, attention_mask=causal_mask)
    out.sum().backward()

    torch.cuda.synchronize()
    param_grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [out, hidden_states.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
    """Compare TE ``TransformerLayer`` against the PyTorch ``TorchGPT`` module.

    Parameters of the TE layer are copied into ``TorchGPT``. Both models run in
    ``eval()`` mode through ``_test_e2e_gpt_accuracy``, and then the output and
    all gradients are compared.
    """
    config = model_configs[model]

    te_gpt = TransformerLayer(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        num_attention_heads=config.num_heads,
        layernorm_epsilon=config.eps,
        attention_dropout=0.1,
        hidden_dropout=0.1,
        params_dtype=dtype,
        fuse_qkv_params=True,
        qkv_weight_interleaved=False,
        parallel_attention_mlp=parallel_attention_mlp,
        device="cuda",
    ).eval()

    torch_gpt = (
        TorchGPT(
            config.hidden_size,
            config.eps,
            config.num_heads,
            parallel_attention_mlp=parallel_attention_mlp,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_gpt.ln.weight = Parameter(
            te_gpt.self_attention.layernorm_qkv.layer_norm_weight.clone()
        )
        torch_gpt.ln.bias = Parameter(te_gpt.self_attention.layernorm_qkv.layer_norm_bias.clone())
        torch_gpt.causal_attn.mhsa.in_proj_weight = Parameter(
            te_gpt.self_attention.layernorm_qkv.weight.clone()
        )
        torch_gpt.causal_attn.mhsa.in_proj_bias = Parameter(
            te_gpt.self_attention.layernorm_qkv.bias.clone()
        )
        torch_gpt.causal_attn.mhsa.out_proj.weight = Parameter(
            te_gpt.self_attention.proj.weight.clone()
        )
        torch_gpt.causal_attn.mhsa.out_proj.bias = Parameter(
            te_gpt.self_attention.proj.bias.clone()
        )
        torch_gpt.ln_mlp.ln.weight = Parameter(te_gpt.layernorm_mlp.layer_norm_weight.clone())
        torch_gpt.ln_mlp.ln.bias = Parameter(te_gpt.layernorm_mlp.layer_norm_bias.clone())
        torch_gpt.ln_mlp.fc1.weight = Parameter(te_gpt.layernorm_mlp.fc1_weight.clone())
        torch_gpt.ln_mlp.fc1.bias = Parameter(te_gpt.layernorm_mlp.fc1_bias.clone())
        torch_gpt.ln_mlp.fc2.weight = Parameter(te_gpt.layernorm_mlp.fc2_weight.clone())
        torch_gpt.ln_mlp.fc2.bias = Parameter(te_gpt.layernorm_mlp.fc2_bias.clone())

    te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config)
    torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config)

    atol = {
        torch.float32: 5e-3,
        torch.half: 5e-2,
        torch.bfloat16: 1e-1,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

    # Check gradients, only for small model
    if model == "small":
        atol[torch.float32] = 5e-2
        rtol = {
            torch.float32: 1e-2,
            torch.half: 1e-2,
            torch.bfloat16: 1e-2,
        }
        # NOTE(review): gradients are paired by parameters() order, which assumes
        # both models register their parameters in the same order — confirm.
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
    """Run one forward/backward pass of an attention ``block``.

    Args:
        block: attention module called with the hidden states plus keyword args.
        bs: batch size.
        dtype: dtype of the random input tensor.
        config: model config; ``max_seqlen_q`` and ``hidden_size`` are read.
        mask_type: ``"causal"`` builds a causal mask; any other value passes
            ``attention_mask=None``.
        te: when True, also pass ``attn_mask_type=mask_type``. Set False for
            modules that do not accept that keyword.

    Returns:
        ``[output, input_grad, *param_grads]`` in ``block.parameters()`` order.
    """
    reset_rng_states()

    inp_hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
    inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) if mask_type == "causal" else None

    forward_kwargs = {}
    if te:
        forward_kwargs["attn_mask_type"] = mask_type
    forward_kwargs["attention_mask"] = inp_attn_mask

    out = block(inp_hidden_states, **forward_kwargs)
    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
    """Compare TE ``MultiheadAttention`` (no input layernorm) against ``TorchMHA``.

    QKV and output-projection parameters are copied from TE into ``TorchMHA``.
    The output and all gradients are then compared.
    """
    config = model_configs[model]

    te_mha = MultiheadAttention(
        config.hidden_size,
        config.num_heads,
        fuse_qkv_params=True,
        params_dtype=dtype,
        qkv_weight_interleaved=False,
        input_layernorm=False,
        device="cuda",
    ).eval()

    torch_mha = (
        TorchMHA(
            config.hidden_size,
            config.num_heads,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_mha.mhsa.in_proj_weight = Parameter(te_mha.qkv.weight.clone())
        torch_mha.mhsa.in_proj_bias = Parameter(te_mha.qkv.bias.clone())
        torch_mha.mhsa.out_proj.weight = Parameter(te_mha.proj.weight.clone())
        torch_mha.mhsa.out_proj.bias = Parameter(te_mha.proj.bias.clone())

    te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type, te=True)
    torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type, te=False)

    # Check output.
    if dtype == torch.float32:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
    else:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

    # Check gradients, only for small model
    if model == "small":
        atol = {
            torch.float32: 5e-2,
            torch.half: 5e-2,
            torch.bfloat16: 5e-2,
        }
        rtol = {
            torch.float32: 1e-2,
            torch.half: 1e-2,
            torch.bfloat16: 1e-2,
        }
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, recipe=None):
    """Run one forward/backward pass of ``block`` and collect output and gradients.

    Args:
        block: module called as ``block(hidden_states)``. If it returns a list or
            tuple, only the first element is used as the output.
        bs: batch size.
        dtype: dtype of the random input tensor.
        config: model config; ``max_seqlen_q`` and ``hidden_size`` are read.
        delay_wgrad_compute: if True, call ``block.backward_dw()`` after
            ``loss.backward()``.
        recipe: quantization recipe. A non-None value enables ``autocast`` with
            this recipe and resets ``FP8GlobalStateManager`` first.

    Returns:
        ``[output, input_grad, *param_grads]`` in ``block.parameters()`` order.
        For a parameter carrying a non-None ``main_grad`` attribute, ``main_grad``
        is collected instead of ``.grad``, and ``.grad`` is asserted to be None.
    """
    reset_rng_states()
    fp8 = recipe is not None
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    with autocast(enabled=fp8, recipe=recipe):
        out = block(inp_hidden_states)
        if isinstance(out, (list, tuple)):
            out = out[0]
    loss = out.sum()
    loss.backward()
    if delay_wgrad_compute:
        # The delayed weight-gradient step is triggered explicitly here.
        block.backward_dw()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            if getattr(p, "main_grad", None) is not None:
                outputs.append(p.main_grad)
                assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
            else:
                outputs.append(p.grad)
    return outputs


def _test_dpa_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of a dot-product-attention ``block``.

    Args:
        block: module called as ``block(query, key, value, attention_mask=mask)``.
        bs: batch size.
        dtype: dtype of the random q/k/v tensors.
        config: model config; ``max_seqlen_q``, ``max_seqlen_kv``, ``num_heads``
            and ``kv_channels`` are read.

    Returns:
        ``[output, query_grad, key_grad, value_grad]``.
    """
    reset_rng_states()

    # Boolean mask of shape (max_seqlen_q, max_seqlen_kv), True strictly above
    # the main diagonal.
    mask = torch.triu(
        torch.ones(config.max_seqlen_q, config.max_seqlen_kv, dtype=torch.bool, device="cuda"),
        diagonal=1,
    )
    # Query uses the q sequence length. Key/value use the kv sequence length so
    # that their shapes agree with the (sq, skv) mask above. RNG draw order is
    # unchanged (q, k, v).
    query, key, value = [
        torch.randn(
            (seqlen, bs, config.num_heads, config.kv_channels),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        for seqlen in (config.max_seqlen_q, config.max_seqlen_kv, config.max_seqlen_kv)
    ]

    query.retain_grad()
    key.retain_grad()
    value.retain_grad()

    out = block(query, key, value, attention_mask=mask)
    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()

    return [out, query.grad, key.grad, value.grad]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_dpa_accuracy(dtype, bs, model):
    """Compare TE ``DotProductAttention`` against ``TorchDotProductAttention``.

    Dropout is 0.0 on both sides. The output and the q/k/v gradients are compared.
    """
    config = model_configs[model]

    te_dpa = (
        DotProductAttention(
            config.num_heads,
            config.kv_channels,
            attention_dropout=0.0,  # disable dropout, FU uses rng differently
        )
        .to(dtype=dtype)
        .cuda()
    )

    torch_dpa = (
        TorchDotProductAttention(
            config.kv_channels,
            0.0,  # dropout
        )
        .to(dtype=dtype)
        .cuda()
    )

    te_outputs = _test_dpa_accuracy(te_dpa, bs, dtype, config)
    torch_outputs = _test_dpa_accuracy(torch_dpa, bs, dtype, config)

    # Check output.
    if dtype == torch.float32:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
    else:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

    # Check q/k/v gradients.
    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)


class TestReturnBiasModule(nn.Module):
    """Wrap a module so that a ``return_bias=True`` result looks like a plain output.

    ``mod(**kwargs)`` is constructed; ``kwargs`` must contain ``return_bias`` and
    ``bias``.
    - With ``return_bias`` set, the wrapped module's forward is unpacked as
      ``(out, bias)``. The bias is added to ``out`` only when ``bias`` is truthy.
    - Otherwise the wrapped module's result is returned unchanged.
    """

    # Helper module, not a pytest test class despite the ``Test`` prefix.
    # Without this, pytest tries to collect it and warns because it has __init__.
    __test__ = False

    def __init__(self, mod, **kwargs):
        super().__init__()
        self.te_module = mod(**kwargs)
        self.return_bias = kwargs["return_bias"]
        self.bias = kwargs["bias"]

    def forward(self, x):
        if self.return_bias:
            out, bias = self.te_module(x)
            if self.bias:
                out = out + bias
            return out
        return self.te_module(x)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_linear_accuracy(dtype, bs, model, return_bias, bias):
    """Compare TE ``Linear`` (via ``TestReturnBiasModule``) against ``torch.nn.Linear``.

    Weight (and bias, if enabled) are copied from TE into the Torch module. The
    output, the input gradient and the parameter gradients are then compared.
    """
    config = model_configs[model]

    te_linear = TestReturnBiasModule(
        Linear,
        in_features=config.hidden_size,
        out_features=4 * config.hidden_size,
        params_dtype=dtype,
        return_bias=return_bias,
        bias=bias,
        device="cuda",
    )

    torch_linear = torch.nn.Linear(
        config.hidden_size,
        4 * config.hidden_size,
        bias=bias,
        device="cuda",
        dtype=dtype,
    )

    # Share params
    with torch.no_grad():
        torch_linear.weight = Parameter(te_linear.te_module.weight.clone())
        if bias:
            torch_linear.bias = Parameter(te_linear.te_module.bias.clone())

    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)

    # Check output and gradients (every collected tensor is compared).
    if model == "small":
        tolerance = 5e-3 if dtype == torch.float32 else 5e-2
        rtol = {
            torch.float32: 1.3e-6,
            torch.half: 1e-2,
            torch.bfloat16: 2e-2,
        }
        for te_output, torch_output in zip(te_outputs, torch_outputs):
            assert_allclose(te_output, torch_output, tolerance, rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_accumulation):
    """Check that ``delay_wgrad_compute=True`` gives bit-identical results.

    Two TE ``Linear`` modules with shared parameters are compared; only the
    second has ``delay_wgrad_compute=True``. When ``fuse_wgrad_accumulation`` is
    set, both weights get identical random ``main_grad`` buffers first.
    """
    config = model_configs[model]

    te_linear_ref = Linear(
        config.hidden_size,
        4 * config.hidden_size,
        bias=bias,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=False,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()

    te_linear = Linear(
        config.hidden_size,
        4 * config.hidden_size,
        bias=bias,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=True,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()

    # Share params
    with torch.no_grad():
        te_linear_ref.weight = Parameter(te_linear.weight.clone())
        if bias:
            te_linear_ref.bias = Parameter(te_linear.bias.clone())
        if fuse_wgrad_accumulation:
            weight = te_linear.weight
            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
            te_linear_ref.weight.main_grad = weight.main_grad.clone()

    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True)
    te_outputs_ref = _test_granular_accuracy(
        te_linear_ref, bs, dtype, config, delay_wgrad_compute=False
    )

    # Same module type and config on both sides: tensor counts must agree.
    assert len(te_outputs) == len(te_outputs_ref)
    # Should be bit-wise match
    for o, o_ref in zip(te_outputs, te_outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
def test_linear_accuracy_save_original_input(dtype, model, recipe):
    """Check that ``save_original_input=True`` gives bit-identical results.

    Two bias-free TE ``Linear`` modules with shared weights and identical
    ``main_grad`` buffers are compared. Only the second has
    ``save_original_input=True``. The test is skipped when FP8 is unavailable,
    for DelayedScaling recipes, and when FP8 is used with a sequence length not
    divisible by 16.
    """
    bs = 1
    fuse_wgrad_accumulation = True
    fp8_model_params = False
    fp8 = recipe is not None

    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        te_linear_ref = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            save_original_input=False,
        ).eval()

        te_linear = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            save_original_input=True,
        ).eval()

    # Share params
    with torch.no_grad():
        te_linear_ref.weight = Parameter(te_linear.weight.clone())
        if fuse_wgrad_accumulation:
            weight = te_linear.weight
            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
            te_linear_ref.weight.main_grad = weight.main_grad.clone()

    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
    te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)

    # Same module type and config on both sides: tensor counts must agree.
    assert len(te_outputs) == len(te_outputs_ref)
    # Should be bit-wise match
    for o, o_ref in zip(te_outputs, te_outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
    """Compare TE ``RMSNorm`` against ``TorchRMSNorm`` with a shared weight.

    The output is compared with an absolute tolerance only. Gradients are
    compared with dtype-specific atol/rtol.
    """
    config = model_configs[model]

    te_rmsnorm = RMSNorm(
        config.hidden_size,
        eps=eps,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_rmsnorm = (
        TorchRMSNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_rmsnorm.weight = Parameter(te_rmsnorm.weight.clone())

    te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)

    atol = {
        torch.float32: 1e-7,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

    # Gradients use a looser fp32 atol than the output.
    atol[torch.float32] = 2e-3
    rtol = {
        torch.float32: 1.3e-6,
        torch.half: 1e-3,
        torch.bfloat16: 1.6e-2,
    }
    # Check gradients
    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
    """Compare TE ``LayerNorm`` against ``TorchLayerNorm`` with shared weight and bias.

    The output is compared with an absolute tolerance only. Gradients are
    compared with dtype-specific atol/rtol.
    """
    config = model_configs[model]

    te_layernorm = LayerNorm(
        config.hidden_size,
        eps=eps,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_layernorm = (
        TorchLayerNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_layernorm.weight = Parameter(te_layernorm.weight.clone())
        torch_layernorm.bias = Parameter(te_layernorm.bias.clone())

    te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)

    atol = {
        torch.float32: 1e-7,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

    rtol = {
        torch.float32: 1.3e-6,
        torch.half: 1e-3,
        torch.bfloat16: 1.6e-2,
    }
    # Gradients use a looser fp32 atol than the output.
    atol[torch.float32] = 1e-4
    # Check gradients
    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
        assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_linear_accuracy(
    dtype, bs, model, normalization, zero_centered_gamma, return_bias, bias
):
    """Compare TE ``LayerNormLinear`` against ``TorchLayerNormLinear``.

    Norm weight, norm bias (skipped for ``"RMSNorm"``), linear weight and linear
    bias (if enabled) are copied from TE. The output and gradients are compared.
    """
    config = model_configs[model]

    te_ln_linear = TestReturnBiasModule(
        LayerNormLinear,
        in_features=config.hidden_size,
        out_features=4 * config.hidden_size,
        eps=config.eps,
        normalization=normalization,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        return_bias=return_bias,
        bias=bias,
        device="cuda",
    )

    torch_ln_linear = (
        TorchLayerNormLinear(
            config.hidden_size,
            4 * config.hidden_size,
            config.eps,
            normalization=normalization,
            zero_centered_gamma=zero_centered_gamma,
            bias=bias,
        )
        .to(dtype=dtype)
        .cuda()
    )

    # Share params
    with torch.no_grad():
        torch_ln_linear.layernorm.weight = Parameter(
            te_ln_linear.te_module.layer_norm_weight.clone()
        )
        if normalization != "RMSNorm":
            torch_ln_linear.layernorm.bias = Parameter(
                te_ln_linear.te_module.layer_norm_bias.clone()
            )
        torch_ln_linear.linear.weight = Parameter(te_ln_linear.te_module.weight.clone())
        if bias:
            torch_ln_linear.linear.bias = Parameter(te_ln_linear.te_module.bias.clone())

    te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)

    atol = {
        torch.float32: 2.5e-4,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }
    rtol = {
        torch.float32: 1e-3,
        torch.half: 4e-2,
        torch.bfloat16: 4e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype])

    if model == "small":
        atol = {
            torch.float32: 1e-3,
            torch.half: 5e-2,
            torch.bfloat16: 5e-2,
        }
        rtol = {
            torch.float32: 1e-3,
            torch.half: 4e-2,
            torch.bfloat16: 4e-2,
        }
        # Check gradients
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_linear_accuracy_delay_wgrad_compute(
    dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation
):
    """Check that ``delay_wgrad_compute=True`` in ``LayerNormLinear`` is bit-exact.

    A reference module (``delay_wgrad_compute=False``) shares all parameters with
    the module under test. When ``fuse_wgrad_accumulation`` is set, both weights
    get identical random ``main_grad`` buffers first.
    """
    config = model_configs[model]

    ln_linear_ref = LayerNormLinear(
        config.hidden_size,
        4 * config.hidden_size,
        config.eps,
        bias=bias,
        normalization=normalization,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
        delay_wgrad_compute=False,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()

    ln_linear = LayerNormLinear(
        config.hidden_size,
        4 * config.hidden_size,
        config.eps,
        bias=bias,
        normalization=normalization,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
        delay_wgrad_compute=True,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()

    # Share params
    with torch.no_grad():
        ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone())
        if normalization != "RMSNorm":
            ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone())
        ln_linear_ref.weight = Parameter(ln_linear.weight.clone())
        if bias:
            ln_linear_ref.bias = Parameter(ln_linear.bias.clone())
        if fuse_wgrad_accumulation:
            weight = ln_linear.weight
            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
            ln_linear_ref.weight.main_grad = weight.main_grad.clone()

    te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, delay_wgrad_compute=True)
    te_outputs_ref = _test_granular_accuracy(
        ln_linear_ref, bs, dtype, config, delay_wgrad_compute=False
    )

    # Same module type and config on both sides: tensor counts must agree.
    assert len(te_outputs) == len(te_outputs_ref)
    # Should be bit-wise match
    for o, o_ref in zip(te_outputs, te_outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):
    """Compare TE ``LayerNormMLP`` against ``TorchLayerNormMLP``.

    Norm weight, norm bias (skipped for ``"RMSNorm"``), fc1/fc2 weights and
    fc1/fc2 biases (if enabled) are copied from TE. The output and gradients are
    compared.
    """
    # Reset RNG state at test start to ensure deterministic model initialization
    reset_rng_states()

    config = model_configs[model]

    te_ln_mlp = TestReturnBiasModule(
        LayerNormMLP,
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        return_bias=return_bias,
        bias=bias,
        device="cuda",
    )

    torch_ln_mlp = (
        TorchLayerNormMLP(
            config.hidden_size,
            4 * config.hidden_size,
            activation=activation,
            normalization=normalization,
            bias=bias,
        )
        .to(dtype=dtype)
        .cuda()
    )

    # Share params
    with torch.no_grad():
        torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.te_module.layer_norm_weight.clone())
        if normalization != "RMSNorm":
            torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.te_module.layer_norm_bias.clone())
        torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.te_module.fc1_weight.clone())
        torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.te_module.fc2_weight.clone())
        if bias:
            torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.te_module.fc1_bias.clone())
            torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.te_module.fc2_bias.clone())

    te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)

    atol = {
        torch.float32: 2e-2,
        torch.half: 5e-2,
        torch.bfloat16: 5e-2,
    }

    rtol = {
        torch.float32: 1e-3,
        torch.half: 4e-2,
        torch.bfloat16: 4e-2,
    }

    # Check output.
    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype])

    # Check gradients, only for small model
    rtol = {
        torch.float32: 1e-3,
        torch.half: 1e-2,
        torch.bfloat16: 4e-2,
    }
    atol[torch.half] = 2e-1
    atol[torch.bfloat16] = 2e-1
    if model == "small":
        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", [2])
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
    dtype, bs, model, bias, fuse_wgrad_accumulation
):
    """Check that ``delay_wgrad_compute=True`` in ``LayerNormMLP`` is bit-exact.

    A reference module (``delay_wgrad_compute=False``) shares all parameters with
    the module under test. When ``fuse_wgrad_accumulation`` is set, the fc1 and
    fc2 weights of both modules get identical random ``main_grad`` buffers first.
    """
    config = model_configs[model]

    ln_mlp = LayerNormMLP(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        eps=config.eps,
        bias=bias,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=True,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()

    ln_mlp_ref = LayerNormMLP(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        eps=config.eps,
        bias=bias,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=False,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()

    # Share params
    with torch.no_grad():
        ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
        ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
        ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
        if bias:
            ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
            ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
        if fuse_wgrad_accumulation:
            ln_mlp.fc1_weight.main_grad = torch.rand_like(ln_mlp.fc1_weight, dtype=torch.float32)
            ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone()
            ln_mlp.fc2_weight.main_grad = torch.rand_like(ln_mlp.fc2_weight, dtype=torch.float32)
            ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone()

    te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True)
    te_outputs_ref = _test_granular_accuracy(
        ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False
    )

    # Same module type and config on both sides: tensor counts must agree.
    assert len(te_outputs) == len(te_outputs_ref)
    # Should be bit-wise match
    for o, o_ref in zip(te_outputs, te_outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


def _test_grouped_linear_accuracy(
1710
1711
1712
1713
1714
1715
1716
1717
1718
    block,
    num_gemms,
    bs,
    dtype,
    config,
    recipe,
    fp8,
    fuse_wgrad_accumulation,
    delay_wgrad_compute=False,
1719
):
1720
1721
1722
1723
1724
    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
1725
        (config.max_seqlen_q, bs, config.hidden_size),
1726
1727
1728
1729
1730
1731
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

1732
    if num_gemms > 1:
1733
1734
        split_size = 1
        if fp8:
1735
            split_size = 16
1736
1737
            if recipe.mxfp8():
                split_size = 128
1738
        m = config.max_seqlen_q // split_size
1739
1740
1741
        dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
        dist.append(dist[-1])  # Manually add a zero
        m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
1742
        m_splits = m_splits * split_size
1743
        assert m_splits.sum() == config.max_seqlen_q and len(m_splits) == num_gemms
1744
    else:
1745
        m_splits = torch.tensor([config.max_seqlen_q])
1746

1747
    with autocast(enabled=fp8, recipe=recipe):
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
        if isinstance(block, GroupedLinear):
            m_splits = m_splits * bs
            out = block(inp_hidden_states, m_splits.tolist())
        else:
            out = torch.cat(
                [
                    block[i](inp)
                    for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                ]
            )
    loss = out.sum()
    loss.backward()
1760
1761
1762
1763
1764
1765
    if delay_wgrad_compute:
        if isinstance(block, GroupedLinear):
            block.backward_dw()
        else:
            for i in range(num_gemms):
                block[i].backward_dw()
1766
1767
1768
1769
1770

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
1771
1772
1773
1774
1775
            if getattr(p, "main_grad", None) is not None:
                outputs.append(p.main_grad)
                assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
            else:
                outputs.append(p.grad)
1776
1777
1778
    return outputs


@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_grouped_linear_accuracy(
    dtype,
    num_gemms,
    bs,
    model,
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
    bias,
    delay_wgrad_compute,
    parallel_mode=None,
    use_cutlass=False,
):
    """Compare ``GroupedLinear`` against a list of independent ``Linear`` modules.

    Both implementations share weights (plus biases and ``main_grad`` buffers
    when enabled) and run the same forward/backward pass. Results must match
    bit-for-bit on the cuBLAS path, or within 1e-3 when ``use_cutlass`` is set.
    """
    fp8 = recipe is not None
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=bias,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            delay_wgrad_compute=delay_wgrad_compute,
            save_original_input=False,
        ).eval()
        sequential_linear = torch.nn.ModuleList(
            [
                Linear(
                    config.hidden_size,
                    4 * config.hidden_size,
                    bias=bias,
                    params_dtype=dtype,
                    parallel_mode=parallel_mode,
                    device="cuda",
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                ).eval()
                for _ in range(num_gemms)
            ]
        )

    # Share params
    with torch.no_grad():
        for i in range(num_gemms):
            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
            if bias:
                sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
            if fuse_wgrad_accumulation:
                weight_i = getattr(grouped_linear, f"weight{i}")
                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
                sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()

    # On ROCm, force the sequential_linear and grouped_linear to use hipblaslt
    # rather than hipblas. The caller's setting is restored even if a run raises,
    # so later tests do not inherit it.
    prev_force_rocm_gemm = os.environ.get("NVTE_FORCE_ROCM_GEMM")
    if IS_HIP_EXTENSION:
        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
    try:
        outputs_ref = _test_grouped_linear_accuracy(
            sequential_linear,
            num_gemms,
            bs,
            dtype,
            config,
            recipe,
            fp8,
            fuse_wgrad_accumulation,
            delay_wgrad_compute,
        )
        outputs = _test_grouped_linear_accuracy(
            grouped_linear,
            num_gemms,
            bs,
            dtype,
            config,
            recipe,
            fp8,
            fuse_wgrad_accumulation,
            delay_wgrad_compute,
        )
    finally:
        if IS_HIP_EXTENSION:
            if prev_force_rocm_gemm is None:
                os.environ.pop("NVTE_FORCE_ROCM_GEMM", None)
            else:
                os.environ["NVTE_FORCE_ROCM_GEMM"] = prev_force_rocm_gemm

    for o, o_ref in zip(outputs, outputs_ref):
        if use_cutlass:
            torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
        else:
            # cuBLAS implementation should be bit-wise match
            torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.skipif(
    torch.cuda.get_device_capability() != (9, 0),
    reason="Only enable CUTLASS grouped gemm on Hopper",
)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_grouped_linear_accuracy_cutlass(
    dtype,
    num_gemms,
    bs,
    model,
    fuse_wgrad_accumulation,
    delay_wgrad_compute,
):
    """Run ``test_grouped_linear_accuracy`` with the CUTLASS grouped GEMM enabled.

    Non-FP8 and without bias. Outputs are compared with a 1e-3 tolerance
    instead of bit-wise equality.
    """
    os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
    try:
        test_grouped_linear_accuracy(
            dtype,
            num_gemms,
            bs,
            model,
            None,  # recipe: no FP8
            False,  # fp8_model_params
            fuse_wgrad_accumulation,
            False,  # bias
            delay_wgrad_compute,
            None,  # parallel_mode
            use_cutlass=True,
        )
    finally:
        # Always unset, even on assertion failure or pytest.skip (which raises),
        # so that subsequent tests do not silently run with CUTLASS enabled.
        os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3])
@pytest.mark.parametrize("bs", [1])
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("fp8_model_params", [False])
@pytest.mark.parametrize("fuse_wgrad_accumulation", [True])
@pytest.mark.parametrize("bias", [False])
@pytest.mark.parametrize("delay_wgrad_compute", [True])
def test_grouped_linear_accuracy_save_original_input(
    dtype,
    num_gemms,
    bs,
    model,
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
    bias,
    delay_wgrad_compute,
    parallel_mode=None,
):
    """Compare ``GroupedLinear(save_original_input=True)`` with sequential ``Linear`` modules.

    Outputs, input gradients and weight gradients must match bit-for-bit.
    """
    fp8 = recipe is not None
    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=bias,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            delay_wgrad_compute=delay_wgrad_compute,
            save_original_input=True,
        ).eval()
        sequential_linear = torch.nn.ModuleList(
            [
                Linear(
                    config.hidden_size,
                    4 * config.hidden_size,
                    bias=bias,
                    params_dtype=dtype,
                    parallel_mode=parallel_mode,
                    device="cuda",
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                ).eval()
                for _ in range(num_gemms)
            ]
        )

    # Share params
    with torch.no_grad():
        for i in range(num_gemms):
            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
            if bias:
                sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
            if fuse_wgrad_accumulation:
                weight_i = getattr(grouped_linear, f"weight{i}")
                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
                sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()

    # Force the sequential_linear and grouped_linear to use hipblaslt rather than
    # hipblas. The caller's setting is restored even if a run raises.
    prev_force_rocm_gemm = os.environ.get("NVTE_FORCE_ROCM_GEMM")
    if IS_HIP_EXTENSION:
        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
    try:
        outputs_ref = _test_grouped_linear_accuracy(
            sequential_linear,
            num_gemms,
            bs,
            dtype,
            config,
            recipe,
            fp8,
            fuse_wgrad_accumulation,
            delay_wgrad_compute,
        )
        outputs = _test_grouped_linear_accuracy(
            grouped_linear,
            num_gemms,
            bs,
            dtype,
            config,
            recipe,
            fp8,
            fuse_wgrad_accumulation,
            delay_wgrad_compute,
        )
    finally:
        if IS_HIP_EXTENSION:
            if prev_force_rocm_gemm is None:
                os.environ.pop("NVTE_FORCE_ROCM_GEMM", None)
            else:
                os.environ["NVTE_FORCE_ROCM_GEMM"] = prev_force_rocm_gemm

    # Should be bit-wise match
    for o, o_ref in zip(outputs, outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("recipe", fp8_recipes + [None])
def test_grouped_linear_accuracy_single_gemm(recipe):
    """Run ``test_grouped_linear_accuracy`` with a single GEMM.

    Split out from the main parametrization to save CI time.
    """
    test_grouped_linear_accuracy(
        dtype=torch.float32,
        num_gemms=1,
        bs=2,
        model="126m",
        recipe=recipe,
        fp8_model_params=True,
        fuse_wgrad_accumulation=True,
        bias=True,
        delay_wgrad_compute=False,
    )


def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
    """Run one forward/backward pass through a padded or unpadded grouped linear.

    If ``block`` is a ``TorchGroupedLinearWithPadding`` it receives the raw
    token splits directly. Otherwise, under FP8, each split is zero-padded to
    the recipe's alignment before the call, and the padding rows are stripped
    from the output afterwards.

    Returns:
        ``[output, input_grad, *param_grads]``.
    """

    def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
        # Round every split up to a multiple of the alignment and append zero rows.
        align_size = 16
        if recipe.mxfp8():
            align_size = 32
        padded_tokens_per_expert = [
            (num_tokens + align_size - 1) // align_size * align_size
            for num_tokens in tokens_per_expert
        ]
        hidden_states = torch.split(hidden_states, tokens_per_expert)
        padded_hidden_states = []
        for hidden_state, actual_num_tokens, padded_num_tokens in zip(
            hidden_states, tokens_per_expert, padded_tokens_per_expert
        ):
            padded_hidden_states.append(hidden_state)
            if padded_num_tokens > actual_num_tokens:
                pad_tensor = torch.zeros(
                    padded_num_tokens - actual_num_tokens,
                    hidden_state.shape[1],
                    dtype=hidden_state.dtype,
                    device=hidden_state.device,
                )
                padded_hidden_states.append(pad_tensor)
        padded_hidden_states = torch.cat(padded_hidden_states, dim=0)
        return padded_hidden_states, padded_tokens_per_expert

    def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert):
        # Keep only the first actual_tokens_per_expert[i] rows of each padded split.
        inputmats = torch.split(
            padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert
        )
        hidden_states = torch.cat(
            [inputmat[: actual_tokens_per_expert[i]] for i, inputmat in enumerate(inputmats)],
            dim=0,
        )

        return hidden_states

    def _generate_random_numbers(n, total_sum):
        # Return n positive ints summing to total_sum, reproducible via `seed`.
        if n <= 0:
            return []
        if n == 1:
            # No break points to sample: the single split takes everything.
            return [total_sum]

        # A private generator seeded with `seed` yields the same sample as
        # reseeding the global `random` module, without mutating global state.
        rng = random.Random(seed)

        breaks = sorted(rng.sample(range(1, total_sum), n - 1))
        random_numbers = (
            [breaks[0]]
            + [breaks[i] - breaks[i - 1] for i in range(1, n - 1)]
            + [total_sum - breaks[-1]]
        )

        return random_numbers

    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.max_seqlen_q * bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs)

    with autocast(enabled=fp8, recipe=recipe):
        if isinstance(block, TorchGroupedLinearWithPadding):
            out = block(inp_hidden_states, m_splits)
        elif fp8:
            padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
                inp_hidden_states, m_splits
            )
            padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits)
            out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits)
        else:
            out = block(inp_hidden_states, m_splits)

    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy(
    dtype,
    num_gemms,
    bs,
    model,
    fp8,
    recipe,
    fp8_model_params,
    parallel_mode=None,
):
    """Compare ``TorchGroupedLinearWithPadding`` with ``GroupedLinear`` on manually padded input.

    Both modules share weights. Results must match bit-for-bit.
    """
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = TorchGroupedLinearWithPadding(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            fp8=fp8,
        ).eval()

    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        ref_grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            save_original_input=False,
        ).eval()

    # Share params
    with torch.no_grad():
        inner_grouped_linear = grouped_linear.linear_fn
        for i in range(num_gemms):
            setattr(
                ref_grouped_linear,
                f"weight{i}",
                Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
            )

    outputs = _test_padding_grouped_linear_accuracy(
        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
    )
    outputs_ref = _test_padding_grouped_linear_accuracy(
        ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
    )

    # Should be bit-wise match
    for o, o_ref in zip(outputs, outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3])
@pytest.mark.parametrize("bs", [1])
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", [False])
def test_padding_grouped_linear_accuracy_save_original_input(
    dtype,
    num_gemms,
    bs,
    model,
    fp8,
    recipe,
    fp8_model_params,
    parallel_mode=None,
):
    """Compare ``TorchGroupedLinearWithPadding`` with ``GroupedLinear(save_original_input=True)``.

    Both modules share weights. Results must match bit-for-bit.
    """
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        grouped_linear = TorchGroupedLinearWithPadding(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            fp8=fp8,
        ).eval()

    with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        ref_grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            save_original_input=True,
        ).eval()

    # Share params
    with torch.no_grad():
        inner_grouped_linear = grouped_linear.linear_fn
        for i in range(num_gemms):
            setattr(
                ref_grouped_linear,
                f"weight{i}",
                Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
            )

    outputs = _test_padding_grouped_linear_accuracy(
        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
    )
    outputs_ref = _test_padding_grouped_linear_accuracy(
        ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
    )

    # Should be bit-wise match
    for o, o_ref in zip(outputs, outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
    """Train ``block`` briefly with SGD, optionally through a captured CUDA graph.

    After three warm-up steps on a side stream, one training step is either
    captured into a CUDA graph and replayed (``graph=True``) or run eagerly
    (``graph=False``) on fresh data copied into the static placeholders.

    Returns:
        ``(output, grads)`` where ``grads`` is ``[input_grad, *param_grads]``.
    """
    reset_rng_states()

    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for graph capture.
    static_input = torch.randn(
        config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
    )
    static_target = torch.randn(
        config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype
    )

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    # Basic training loop.
    def train_step():
        # set_to_none=False keeps .grad tensors alive, so their addresses stay fixed for the graph.
        optimizer.zero_grad(set_to_none=False)
        out = block(static_input)
        loss = loss_fn(out, static_target)
        loss.backward()
        optimizer.step()
        return out

    # Warmup steps in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            train_step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture graph.
    g = None
    static_output = None
    if graph:
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_output = train_step()

    # Run with new data.
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    if graph:
        g.replay()
    else:
        static_output = train_step()

    grads = [static_input.grad]
    for p in block.parameters():
        if p.requires_grad:
            grads.append(p.grad)

    with torch.no_grad():
        output = static_output.clone()
    return output, grads


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_cuda_graph(dtype, bs, model):
    """Check that a CUDA-graphed ``TransformerLayer`` training step matches the eager one.

    Two layers with identical parameters are trained. Outputs, updated
    parameters and gradients must agree within 1e-3.
    """
    if NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("Cuda Graphs are not supported in debug mode.")

    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block_args = (
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
    )
    block_kwargs = dict(
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        device="cuda",
    )
    block = TransformerLayer(*block_args, **block_kwargs)
    graphed_block = TransformerLayer(*block_args, **block_kwargs)
    # Start both layers from identical parameters.
    with torch.no_grad():
        for param1, param2 in zip(block.parameters(), graphed_block.parameters()):
            param2.copy_(param1)

    out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
    graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
    params = list(block.parameters())
    graphed_params = list(graphed_block.parameters())

    # Check that results match
    assert_allclose(out, graphed_out, 1e-3)
    assert_allclose(params, graphed_params, 1e-3)
    assert_allclose(grads, graphed_grads, 1e-3)


def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
    """Run one FP8-autocast forward/backward through a fresh ``TransformerLayer``.

    ``fp8_model_params`` controls whether the layer is built under
    ``quantized_model_init``. The forward pass always runs under FP8
    ``autocast`` with ``recipe``.

    Returns:
        ``[output, input_grad, *param_grads]``.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    with quantized_model_init(enabled=fp8_model_params, recipe=recipe):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.kv_channels,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

    with autocast(enabled=True, recipe=recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    outputs = [te_out, te_inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
    """Compare a ``TransformerLayer`` built with FP8 parameters against one with regular parameters.

    Both runs use FP8 autocast with the same recipe. Outputs and gradients must
    agree within loose FP8 tolerances.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")

    config = model_configs[model]

    outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe)
    outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True, recipe)

    # Check that results match
    tols = dict(rtol=0.125, atol=0.0675)
    for i, (ref, test) in enumerate(zip(outputs, outputs_fp8_params)):
        torch.testing.assert_close(
            test,
            ref,
            # A callable prefixes (instead of replacing) torch's detailed mismatch report.
            msg=lambda m, i=i: f"Mismatch in tensor {i}: {m}",
            **tols,
        )


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_transformer_layer_hidden_states_format(dtype, bs, model):
    """Check that ``TransformerLayer`` gives the same output for sbhd, bshd and thd inputs.

    Three layers with identical weights are built, one per ``attn_input_format``,
    and fed the same data laid out in each format.
    """
    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    def _make_block(attn_input_format, **extra_kwargs):
        # Set `torch.manual_seed` to make sure the weights are identical to the
        # other layers. Set `*dropout` values to 0 to make sure the forward pass
        # is identical to the other layers.
        torch.manual_seed(0)
        return TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0,
            attention_dropout=0,
            kv_channels=config.kv_channels,
            params_dtype=dtype,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            device="cuda",
            attn_input_format=attn_input_format,
            **extra_kwargs,
        )

    block_sbhd = _make_block("sbhd")
    block_bshd = _make_block("bshd")
    block_thd = _make_block("thd", self_attn_mask_type="padding_causal")

    for (n1, p1), (n2, p2), (n3, p3) in zip(
        block_bshd.named_parameters(), block_sbhd.named_parameters(), block_thd.named_parameters()
    ):
        assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical"

    x_sbhd = torch.randn(
        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    x_bshd = x_sbhd.transpose(0, 1).contiguous()
    x_thd = x_bshd.reshape(bs * config.max_seqlen_q, config.hidden_size).contiguous()
    # Every sequence has length max_seqlen_q, so the cumulative offsets are 0, S, 2S, ...
    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.max_seqlen_q

    # To make sure forward is also identical (just in case some module decides
    # to act fancy)
    torch.manual_seed(0)
    y_sbhd = block_sbhd(x_sbhd)

    # To make sure forward is also identical (just in case some module decides
    # to act fancy)
    torch.manual_seed(0)
    y_bshd = block_bshd(x_bshd)

    # Check that results match
    torch.testing.assert_close(
        y_bshd,
        y_sbhd.transpose(0, 1).contiguous(),
    )

    # THD is not supported in float32 and on GPUs older than Ampere, skip the test here
    if dtype != torch.float32 and sm_80plus:
        # To make sure forward is also identical (just in case some module decides
        # to act fancy)
        torch.manual_seed(0)
        y_thd = block_thd(
            x_thd,
            cu_seqlens_q=x_thd_cumsum,
            cu_seqlens_kv=x_thd_cumsum,
            max_seqlen_q=config.max_seqlen_q,
            max_seqlen_kv=config.max_seqlen_kv,
        )

        torch.testing.assert_close(
            y_bshd,
            y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(),
        )


@pytest.mark.parametrize(
    "shape",
    [
        (1, 127, 128, 512),
        (8, 15, 128, 512),
        (8, 1027, 128, 512),
        (16, 10027, 128, 512),
    ],
)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm)
def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
    """Compare ``general_grouped_gemm`` against ``z`` independent ``general_gemm`` calls.

    ``shape`` is ``(z, m, k, n)``. The ``m`` rows are split into ``z`` groups whose
    sizes (``m_splits``) are a random partition of ``m``; individual groups may be
    empty. ``layout`` selects which pass is emulated:

    * ``"TN"``: weight ``(n, k)`` with input rows, written into a single ``(m, n)`` output.
    * ``"NN"``: weight ``(n, k)`` with grad_output rows, written into a single
      ``(m, k)`` dgrad output.
    * ``"NT"``: input rows with grad_output rows, written into ``z`` separate
      ``(n, k)`` wgrad outputs.

    The non-CUTLASS path must match the reference bit-for-bit. The CUTLASS path
    (``NVTE_USE_CUTLASS_GROUPED_GEMM=1``) is checked with a loose tolerance.

    Environment overrides are restored to their previous values in a ``finally``
    block. This way a failing GEMM call cannot leak them into later tests.
    """
    torch.manual_seed(0)
    z, m, k, n = shape

    # Random sorted split points give per-group row counts that sum to m.
    dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist()
    m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
    assert m_splits.sum() == m and len(m_splits) == z
    m_splits = m_splits.tolist()

    if layout == "TN":
        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
        B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
        out = [torch.randn(m, n, dtype=dtype, device="cuda")]  # output
        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
        grad = False
        single_output = True
    elif layout == "NN":
        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
        B = list(
            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
        )  # grad_output
        out = [torch.randn(m, k, dtype=dtype, device="cuda")]  # dgrad
        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
        grad = True
        single_output = True
    else:  # layout == "NT"
        A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
        B = list(
            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
        )  # grad_output
        out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # wgrad
        out_ref = [o.clone() for o in out]
        grad = True
        single_output = False

    # Environment overrides this parametrization needs while the GEMMs run.
    env_overrides = {}
    if use_cutlass:
        env_overrides["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
    if IS_HIP_EXTENSION:
        # On ROCm, force both the reference general_gemm calls and
        # general_grouped_gemm to use hipBLASLt rather than hipBLAS.
        env_overrides["NVTE_FORCE_ROCM_GEMM"] = "1"

    # Record prior values (None means "unset") so they can be restored exactly.
    saved_env = {name: os.environ.get(name) for name in env_overrides}
    os.environ.update(env_overrides)
    try:
        # Reference: one general_gemm per group.
        for i in range(z):
            general_gemm(
                A[i],
                B[i],
                get_workspace(),
                dtype,
                grad=grad,
                accumulate=accumulate,
                layout=layout,
                out=out_ref[i],
            )
        if single_output:
            out_ref = [torch.cat(out_ref)]

        general_grouped_gemm(
            A,
            B,
            out,
            dtype,
            get_multi_stream_cublas_workspace(),
            m_splits=m_splits,
            grad=grad,
            accumulate=accumulate,
            layout=layout,
            single_output=single_output,
        )
    finally:
        for name, value in saved_env.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value

    for o, o_ref in zip(out, out_ref):
        if not use_cutlass:
            # The non-CUTLASS implementation should be a bit-wise match.
            torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
        else:
            torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
    "input_quantizer",
    [
        Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"),
        MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3),
    ],
)
@pytest.mark.parametrize(
    "out_quantizer",
    [
        Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"),
        MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3),
        Float8Quantizer(
            torch.ones(1).cuda().squeeze(), torch.ones(1).cuda().squeeze(), tex.DType.kFloat8E4M3
        ),
    ],
)
def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_quantizer):
    """Check that output quantization requested from ``general_gemm`` matches quantizing afterwards.

    For MXFP8 and current-scaling quantizers the output quantization is not fused
    into the GEMM. The expected flow is:
    FP8 operands -> cuBLAS GEMM -> high-precision output -> quantize -> FP8 output.

    Two checks are made:
    1. The explicitly quantized GEMM output is compared against a float64 PyTorch
       matmul of the dequantized operands, using FP8-level tolerances.
    2. It is compared against the output ``general_gemm`` quantized itself.
    """
    uses_mxfp8 = any(isinstance(q, MXFP8Quantizer) for q in (input_quantizer, out_quantizer))
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if uses_mxfp8 and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Both operands use the same input quantizer. The activation is drawn first,
    # then the weight.
    activation = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype))
    weight = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype))
    out_dtype = torch.float32
    common_kwargs = dict(bias=None, use_split_accumulator=False)

    # Path 1: let general_gemm quantize its own output.
    in_gemm_quantized, *_ = general_gemm(
        weight,
        activation,
        get_workspace(),
        out_dtype,
        quantization_params=out_quantizer,
        **common_kwargs,
    )

    # Path 2: plain high-precision GEMM, then quantize explicitly.
    plain_result, *_ = general_gemm(
        weight,
        activation,
        get_workspace(),
        out_dtype,
        quantization_params=None,
        **common_kwargs,
    )
    post_gemm_quantized = out_quantizer(plain_result)

    # Reference: activation @ weight^T, computed in float64 on the dequantized values.
    reference = torch.matmul(
        activation.dequantize().to(torch.float64),
        weight.dequantize().to(torch.float64).transpose(0, 1),
    )

    # Loose tolerances absorb the FP8 quantization error of the output.
    torch.testing.assert_close(
        reference.to(out_dtype), post_gemm_quantized.dequantize(), rtol=0.125, atol=0.0675
    )
    # Quantizing inside general_gemm and quantizing afterwards must agree.
    torch.testing.assert_close(post_gemm_quantized.dequantize(), in_gemm_quantized.dequantize())


@pytest.mark.parametrize(
    "shape",
    [
        (1, 128, 128, 512),
        (8, 1024, 128, 512),
        (16, 4096, 128, 512),
    ],
)
@pytest.mark.parametrize("accumulate", [False, True])
def test_fp8_grouped_gemm(shape, accumulate):
    """Check FP8 ``general_grouped_gemm`` against per-group ``general_gemm`` calls.

    ``shape`` is ``(z, m, k, n)``. The ``m`` input rows are split into ``z``
    equal groups, and group ``i`` uses its own ``(n, k)`` weight. Both operands are
    quantized to FP8 E4M3 with separate ``Float8Quantizer`` instances. The grouped
    result must equal the per-group baseline bit-for-bit.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)

    z, m, k, n = shape
    # Equal-sized groups. This assumes m is divisible by z, which holds for every case above.
    group_rows = [m // z] * z

    dtype = torch.bfloat16
    weights = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]
    inputs = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), group_rows)
    # `out` holds views into one (m, n) buffer; `out_ref` holds independent copies.
    out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), group_rows)
    out_ref = [chunk.clone() for chunk in out]

    # A made-up scale in [1, 2). The FP8 path is expected to tolerate it.
    scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze()
    amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda")

    def new_quantizer():
        # Give each quantizer its own copies so none share scale/amax storage.
        return Float8Quantizer(scale.clone(), amax.clone(), tex.DType.kFloat8E4M3)

    weight_quantizers = [new_quantizer() for _ in range(z)]
    input_quantizers = [new_quantizer() for _ in range(z)]

    # Quantize the weight, then the input, for each group in turn.
    quantized_pairs = [
        (wq(w), iq(x))
        for wq, iq, w, x in zip(weight_quantizers, input_quantizers, weights, inputs)
    ]
    A_fp8 = [pair[0] for pair in quantized_pairs]
    B_fp8 = [pair[1] for pair in quantized_pairs]

    # Reference: one general_gemm per group, writing into out_ref.
    for w_fp8, x_fp8, ref_chunk in zip(A_fp8, B_fp8, out_ref):
        general_gemm(
            w_fp8,
            x_fp8,
            get_workspace(),
            dtype,
            out=ref_chunk,
            accumulate=accumulate,
        )

    general_grouped_gemm(
        A_fp8,
        B_fp8,
        out,
        dtype,
        get_multi_stream_cublas_workspace(),
        m_splits=group_rows,
        accumulate=accumulate,
    )

    # Grouped and per-group results must be a bit-wise match.
    for got, expected in zip(out, out_ref):
        torch.testing.assert_close(got, expected, rtol=0, atol=0)


def test_noncontiguous():
    def _create2modules(m, params):
        mod1 = m(*params)
        mod2 = m(*params)
        for p1, p2 in zip(mod1.parameters(), mod2.parameters()):
            p2.data = p1.data.clone()

        return mod1, mod2

    def _run_module(m, inp):
        out = m(inp)
        out.sum().backward()
        ret = [out]
        if inp.grad is not None:
            ret.append(inp.grad)

        for p in m.parameters():
            if p.requires_grad:
                ret.append(p.grad)
        return ret

    a = torch.randn((128, 256), device="cuda", requires_grad=True)
    a = a.T
    assert not a.is_contiguous(), "The test is supposed to test noncontiguous input."

    b = a.contiguous()

    # LayerNorm
    ln1, ln2 = _create2modules(LayerNorm, [128])
    outT = _run_module(ln1, a)
    out = _run_module(ln2, b)

    assert_allclose(out, outT, 1e-7)

    # RMSNorm
    ln1, ln2 = _create2modules(RMSNorm, [128])
    outT = _run_module(ln1, a)
    out = _run_module(ln2, b)

    assert_allclose(out, outT, 1e-7)

    # GEMM
    g1, g2 = _create2modules(Linear, [128, 128])
    outT = _run_module(g1, a)
    out = _run_module(g2, b)

    assert_allclose(out, outT, 1e-7)