# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
import math
6
import os
7
from typing import Dict, List, Optional
8
import pytest
9
import copy
10
11
12
13
14

import torch
import torch.nn as nn
from torch.nn import Parameter

15
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init
16
17
18
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
19
    attention_mask_func,
20
    is_bf16_compatible,
21
22
)
from transformer_engine.pytorch import (
23
24
25
26
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
27
    GroupedLinear,
28
29
30
31
32
    MultiheadAttention,
    RMSNorm,
    TransformerLayer,
    LayerNorm,
    InferenceParams,
33
34
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
35
36
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
37
from transformer_engine.pytorch.utils import get_device_compute_capability
38
import transformer_engine_torch as tex
39

40
41
42
# Only run FP8 tests where the FP8 state manager reports support; the second
# element of the returned pair is the reason string used when skipping.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

# True when the device compute capability is (8, 0) or newer.
sm_80plus = get_device_compute_capability() >= (8, 0)
44

45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Fixed seed applied to both the CPU and the CUDA generator.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
# These snapshots are what reset_rng_states() restores before each model run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()


class ModelConfig:
    """Plain container for the model dimensions used by the tests in this file.

    Each constructor argument is stored unchanged under an attribute of the
    same name.
    """

    def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
        # Paired assignments keep the attribute creation order identical to
        # the parameter order.
        self.hidden_size, self.eps = hidden_size, eps
        self.num_attention_heads, self.embed = num_attention_heads, embed
        self.num_layers, self.seq_len = num_layers, seq_len


# Model shapes used by the training-style tests, keyed by a size label.
model_configs = {
    "126m": ModelConfig(
        hidden_size=768,
        eps=1e-5,
        num_attention_heads=12,
        embed=64,
        num_layers=12,
        seq_len=2048,
    ),
}

67
68
69
70
71
72
73
74
# Same model shape as the training configuration, but with seq_len=16.
model_configs_inference = {
    "126m": ModelConfig(
        hidden_size=768,
        eps=1e-5,
        num_attention_heads=12,
        embed=64,
        num_layers=12,
        seq_len=16,
    ),
}
# Parametrization axes for the inference tests.
backends_inference = ["FlashAttention", "UnfusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]

75
# Parameter dtypes to test; bfloat16 is added only when the device supports it.
param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

# Values used by the pytest parametrizations below.
batch_sizes = [1, 2]

all_boolean = [True, False]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]

all_normalizations = ["LayerNorm", "RMSNorm"]

mask_types = ["causal", "no_mask"]


90
91
92
93
def get_causal_attn_mask(sq: int, device="cuda") -> torch.Tensor:
    """Build a boolean causal attention mask of shape ``(sq, sq)``.

    Entries strictly above the main diagonal are ``True`` and every other
    entry is ``False``.

    Args:
        sq: Sequence length (the mask is square).
        device: Device on which the mask is created. Defaults to ``"cuda"``,
            which is what every existing caller relies on.

    Returns:
        A ``torch.bool`` tensor of shape ``(sq, sq)``.
    """
    return torch.triu(torch.ones(sq, sq, device=device), diagonal=1).bool()


94
95
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.

    Args:
        dtype: One of ``torch.float32``, ``torch.float16`` or
            ``torch.bfloat16``.

    Returns:
        A new dict with ``rtol`` and ``atol`` entries. A fresh dict is built
        on every call because callers modify the result in place.

    Raises:
        ValueError: If ``dtype`` is not one of the supported dtypes.
    """
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    raise ValueError(f"Unsupported dtype ({dtype})")


def assert_allclose(
    l1: List[torch.Tensor],
    l2: List[torch.Tensor],
    atol: float,
) -> None:
    """Check that two sequences of tensors are pairwise close.

    Each pair is compared with ``torch.allclose`` using the given absolute
    tolerance (the relative tolerance is left at that function's default).

    Args:
        l1: First sequence of tensors.
        l2: Second sequence of tensors, same length as ``l1``.
        atol: Absolute tolerance passed to ``torch.allclose``.

    Raises:
        AssertionError: If the sequences have different lengths, or if a pair
            is not close. In the second case the message reports the index of
            the offending pair and the flattened location and values of its
            largest absolute difference.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if torch.allclose(t1, t2, atol=atol):
            continue
        diff = torch.abs(t1 - t2).flatten()
        m = torch.argmax(diff)
        msg = (
            f"Outputs not close enough in tensor at idx={i}. "
            f"Location of the maximum difference: {m.item()} "
            f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
            f"(diff {diff[m].item()})."
        )
        raise AssertionError(msg)
128
129
130


def reset_rng_states() -> None:
    """Revert the CPU and CUDA generators to the recorded RNG state.

    The state restored is whatever the module-level ``_cpu_rng_state`` and
    ``_cuda_rng_state`` currently hold.
    """
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that resets the global FP8 state after the test body runs."""
    yield
    # Teardown: runs once the test using this fixture has finished.
    FP8GlobalStateManager.reset()
140
141


142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
class TorchScaledMaskedSoftmax(nn.Module):
    """Reference softmax over the last dimension with optional scale and mask.

    The computation is carried out in float32 and the result is cast back to
    the dtype of the input.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        """Return ``softmax(mask(inp * scale))`` in the dtype of ``inp``.

        Scaling is skipped when ``scale`` is None and masking is skipped when
        ``mask`` is None.
        """
        out_dtype = inp.dtype
        scores = inp.float()

        if scale is not None:
            scores = scores * scale
        if mask is not None:
            scores = attention_mask_func(scores, mask)

        return torch.nn.Softmax(dim=-1)(scores).to(out_dtype)


class TorchDotProductAttention(torch.nn.Module):
    """Unfused reference implementation of scaled dot-product attention.

    Args:
        kv_channels: Per-head channel count; raw scores are divided by its
            square root.
        attention_dropout: Dropout probability applied to the attention
            probabilities.
    """

    def __init__(
        self,
        kv_channels: int,
        attention_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.norm_factor = math.sqrt(kv_channels)
        self.scale_mask_softmax = TorchScaledMaskedSoftmax()
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention for inputs laid out as ``[s, b, np, hn]``.

        Returns the context tensor reshaped to ``[sq, b, np * hn]``.
        """
        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # Preallocating result tensor: [b * np, sq, sk]. It is created on the
        # query's own device so that it always matches the other operands of
        # the batched matmul below.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 means the uninitialized contents of matmul_result are ignored.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
        attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        # query_layer was reshaped above, so its size(0) is still sq.
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        context_layer = context_layer.view(seqlen, batch_size, -1)

        return context_layer

249

250
class TorchLayerNorm(nn.Module):
    """Reference LayerNorm that computes in float32.

    Args:
        in_features: Size of the normalized (last) dimension.
        eps: Epsilon passed to ``torch.nn.functional.layer_norm``.
        zero_centered_gamma: When True the stored weight is an offset from 1,
            i.e. the effective scale is ``1 + weight``.
    """

    def __init__(self, in_features: int, eps: float, zero_centered_gamma: bool):
        super().__init__()
        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        # The effective scale must start at 1 in both modes: a zero-centered
        # weight starts at 0 (forward adds 1), a plain weight starts at 1.
        initial_value = torch.zeros(in_features) if zero_centered_gamma else torch.ones(in_features)
        # Assigning nn.Parameter attributes registers them on the module.
        self.weight = nn.Parameter(initial_value)
        self.bias = nn.Parameter(torch.zeros(in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension and return it in ``x.dtype``."""
        w = self.weight if not self.zero_centered_gamma else 1 + self.weight
        w = w.to(torch.float32)
        b = self.bias.to(torch.float32)
        inp = x.to(torch.float32)
        out = torch.nn.functional.layer_norm(
            inp, (self.in_features,), weight=w, bias=b, eps=self.eps
        )
        return out.to(x.dtype)

273

274
275
# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
    """Reference RMSNorm whose statistics are computed in float32.

    Args:
        in_features: Size of the normalized (last) dimension.
        zero_centered_gamma: When True the stored weight is an offset from 1,
            i.e. the effective scale is ``1 + weight``.
        eps: Value added to the mean of squares before the inverse square root.
    """

    def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
        super().__init__()

        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        # The effective scale must start at 1 in both modes: a zero-centered
        # weight starts at 0 (forward adds 1), a plain weight starts at 1.
        initial_value = torch.zeros(in_features) if zero_centered_gamma else torch.ones(in_features)
        # Assigning an nn.Parameter attribute registers it on the module.
        self.weight = nn.Parameter(initial_value)

    def forward(self, x):
        """Scale ``x`` by the inverse RMS of its last dimension; returns ``x.dtype``."""
        norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
        d_x = self.in_features

        rms_x2 = norm_x2 / d_x + self.eps
        r_rms_x = rms_x2 ** (-1.0 / 2)
        x_normed = x * r_rms_x

        w = self.weight.float()
        if self.zero_centered_gamma:
            w = 1 + w
        return (w * x_normed).to(x.dtype)
299

300

301
class TorchLayerNormLinear(nn.Module):
    """Reference module: a normalization layer followed by ``nn.Linear``.

    Args:
        in_features: Input feature size (also the normalized dimension).
        out_features: Output feature size of the linear layer.
        eps: Epsilon forwarded to the normalization layer.
        bias: Whether the linear layer has a bias term.
        normalization: ``"LayerNorm"`` or ``"RMSNorm"``.
        zero_centered_gamma: Forwarded to the normalization layer.

    Raises:
        RuntimeError: If ``normalization`` is not one of the supported names.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float,
        bias: bool = True,
        normalization: str = "LayerNorm",
        zero_centered_gamma: bool = False,
    ):
        super().__init__()
        if normalization == "LayerNorm":
            self.layernorm = TorchLayerNorm(
                in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
            )
        elif normalization == "RMSNorm":
            self.layernorm = TorchRMSNorm(
                in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
            )
        else:
            raise RuntimeError("Unsupported normalization")

        # Honor the ``bias`` argument; the default (True) matches nn.Linear's
        # own default, so existing callers see no change.
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the normalization, then the linear layer."""
        return self.linear(self.layernorm(x))


class TorchMHA(nn.Module):
    """Self-attention reference built on ``nn.MultiheadAttention``.

    The wrapped module is created with dropout 0.1, bias enabled and
    ``batch_first=False``.

    Args:
        hidden_size: Embedding dimension.
        num_attention_heads: Number of attention heads.
    """

    def __init__(self, hidden_size: int, num_attention_heads: int):
        super().__init__()
        self.mhsa = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            dropout=0.1,
            bias=True,
            batch_first=False,
        )

    def forward(self, x, attention_mask=None):
        """Run self-attention with ``x`` as query, key and value.

        Returns only the attention output (attention weights are not requested).
        """
        output = self.mhsa(x, x, x, attn_mask=attention_mask, need_weights=False)
        # The wrapped module returns (output, weights); keep only the output.
        if isinstance(output, tuple):
            output = output[0]
        return output

346

347
348
349
class TorchQuickGELU(nn.Module):
    """Quick-GELU reference activation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(1.702 * input)
        return input * gate
350

351

352
353
354
355
class TorchSquaredRELU(nn.Module):
    """Squared-ReLU reference activation: ``x * x`` where ``x > 0``, else 0."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Multiply by the boolean gate first, then by the input again; this
        # keeps the same evaluation order as a single chained product.
        positive = input > 0
        gated = positive * input
        return gated * input

356
357
358
359
360
361
362
363
364
365

# Activation name -> reference module. For the gated names ("geglu", "reglu",
# "swiglu") the entry is the function applied to the gate half by TorchGLU.
_supported_act = dict(
    geglu=nn.GELU(approximate="tanh"),
    gelu=nn.GELU(approximate="tanh"),
    reglu=nn.ReLU(),
    relu=nn.ReLU(),
    swiglu=nn.SiLU(),
    qgelu=TorchQuickGELU(),
    srelu=TorchSquaredRELU(),
)
366

367

368
369
370
371
372
373
374
class TorchGLU(nn.Module):
    """Gated linear unit reference.

    The last dimension of the input is split in two halves; the activation
    looked up in ``_supported_act`` is applied to the first half and the
    result is multiplied by the second half.

    Args:
        activation: Key into ``_supported_act``.
    """

    def __init__(self, activation: str):
        super().__init__()
        self.act = _supported_act[activation]

    def forward(self, x):
        """Return ``act(first_half) * second_half`` along the last dimension."""
        shape = x.size(-1)
        a = x[..., : shape // 2]
        b = x[..., (shape // 2) :]
        a = self.act(a)
        return a * b
379

380

381
class TorchLayerNormMLP(nn.Module):
    """Reference MLP: normalization -> fc1 -> activation -> fc2.

    Args:
        hidden_size: Input and output feature size.
        ffn_hidden_size: Feature size fed into ``fc2``.
        eps: Epsilon forwarded to the normalization layer.
        activation: Key into ``_supported_act``. Names containing ``"glu"``
            use ``TorchGLU`` and double the output width of ``fc1``.
        normalization: ``"LayerNorm"`` or ``"RMSNorm"``.

    Raises:
        RuntimeError: If ``normalization`` is not one of the supported names.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        eps: float = 1e-5,
        activation="gelu",
        normalization: str = "LayerNorm",
    ):
        super().__init__()
        if normalization == "LayerNorm":
            self.ln = TorchLayerNorm(hidden_size, eps=eps, zero_centered_gamma=False)
        elif normalization == "RMSNorm":
            self.ln = TorchRMSNorm(hidden_size, eps=eps, zero_centered_gamma=False)
        else:
            raise RuntimeError("Unsupported normalization")
        if "glu" in activation:
            # Gated activations consume two halves, so fc1 produces twice the width.
            fc1_output_features = 2 * ffn_hidden_size
            self.gelu = TorchGLU(activation)
        else:
            fc1_output_features = ffn_hidden_size
            self.gelu = _supported_act[activation]

        self.fc1 = nn.Linear(hidden_size, fc1_output_features)
        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)

    def forward(self, x):
        """Apply normalization, fc1, the activation and fc2 in that order."""
        return self.fc2(self.gelu(self.fc1(self.ln(x))))
409
410
411


class TorchGPT(nn.Module):
    """Reference transformer block: LayerNorm + self-attention + LayerNorm-MLP.

    Args:
        hidden_size: Model width; the MLP uses ``4 * hidden_size`` internally.
        eps: Epsilon for both normalization layers.
        num_attention_heads: Number of attention heads.
        parallel_attention_mlp: When True the MLP reads the block input and
            its output is added together with the attention output; when
            False the MLP runs after the attention residual has been added.
    """

    def __init__(
        self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool
    ):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)
        self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
        self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
        self.parallel_attention_mlp = parallel_attention_mlp

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the block on ``x``; dropout (p=0.1) is active only in training mode."""
        a = self.ln(x)
        b = self.causal_attn(a, attention_mask)
        if self.parallel_attention_mlp:
            n = self.ln_mlp(x)
            x = x + nn.functional.dropout(b + n, p=0.1, training=self.training)
        else:
            x = x + nn.functional.dropout(b, p=0.1, training=self.training)
            n = self.ln_mlp(x)
            x = x + nn.functional.dropout(n, p=0.1, training=self.training)
        return x


438
def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
    """Run one forward/backward pass of a TransformerLayer on CUDA.

    RNG and global FP8 state are reset first, so two calls that differ only
    in ``recompute`` start from the same state.

    Args:
        bs: Batch size of the random input.
        dtype: Parameter and input dtype.
        config: Object providing hidden_size, num_attention_heads, eps, embed,
            num_layers and seq_len.
        fp8: Enables ``fp8_autocast`` for the forward pass.
        fp8_model_params: Together with ``fp8``, enables ``fp8_model_init``.
        recompute: Passed as ``checkpoint_core_attention`` to the layer.

    Returns:
        List holding the output, the input gradient, and the gradient of every
        parameter that requires grad.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    with fp8_model_init(enabled=fp8 and fp8_model_params):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    te_inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)

    with fp8_autocast(enabled=fp8):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=te_inp_attn_mask,
            checkpoint_core_attention=recompute,
        )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    outputs = [te_out, te_inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
    """Compare a run with core-attention recompute against one without it."""
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    outputs = _test_e2e_selective_recompute(
        bs, dtype, config, fp8, fp8_model_params, recompute=False
    )
    outputs_recompute = _test_e2e_selective_recompute(
        bs, dtype, config, fp8, fp8_model_params, recompute=True
    )

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols["atol"] = 1e-4
    if fp8 or fp8_model_params:
        # Looser tolerances whenever either FP8 flag is set.
        tols.update(dict(rtol=0.125, atol=0.0675))
    for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )
521
522


523
def _test_e2e_full_recompute(
    bs, dtype, config, fp8, fp8_model_params=False, recompute=False, use_reentrant=True
):
    """Run one forward/backward pass, optionally through ``te_checkpoint``.

    RNG and global FP8 state are reset first, so two calls that differ only
    in ``recompute`` start from the same state.

    Args:
        bs: Batch size of the random input.
        dtype: Parameter and input dtype.
        config: Object providing hidden_size, num_attention_heads, eps, embed,
            num_layers and seq_len.
        fp8: Enables ``fp8_autocast`` for the forward pass.
        fp8_model_params: Together with ``fp8``, enables ``fp8_model_init``.
        recompute: When True the layer is invoked via ``te_checkpoint``.
        use_reentrant: Forwarded to ``te_checkpoint``. It also decides whether
            the input requires grad and whether its gradient is returned.

    Returns:
        ``(outputs, names)``: the output tensor, the input gradient (only when
        ``use_reentrant`` is True) and the gradient of every parameter that
        requires grad, plus a parallel list of labels.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    with fp8_model_init(enabled=fp8 and fp8_model_params):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            params_dtype=dtype,
            fuse_qkv_params=True,
            device="cuda",
        )

    te_inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=use_reentrant,
    )
    if use_reentrant:
        te_inp_hidden_states.retain_grad()
    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)

    with fp8_autocast(enabled=fp8):
        if recompute:
            te_out = te_checkpoint(
                block,
                te_inp_hidden_states,
                attention_mask=te_inp_attn_mask,
                checkpoint_core_attention=False,
                distribute_saved_activations=False,
                tp_group=None,
                use_reentrant=use_reentrant,
            )
        else:
            te_out = block(
                te_inp_hidden_states,
                attention_mask=te_inp_attn_mask,
                checkpoint_core_attention=False,
            )
    loss = te_out.sum()
    loss.backward()
    torch.cuda.synchronize()

    outputs = [te_out]
    names = ["output"]
    if use_reentrant:
        outputs.append(te_inp_hidden_states.grad)
        names.append("input")
    for name, p in block.named_parameters():
        if p.requires_grad:
            outputs.append(p.grad)
            names.append(name)

    return outputs, names
593
594
595
596
597
598


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_reentrant", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, use_reentrant):
    """Compare a run through ``te_checkpoint`` against a plain run."""
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    fusion_flag = "NVTE_BIAS_GELU_NVFUSION"
    previous_flag = os.environ.get(fusion_flag)
    if not use_reentrant:
        # Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion
        os.environ[fusion_flag] = "0"

    try:
        outputs, _ = _test_e2e_full_recompute(
            bs, dtype, config, fp8, fp8_model_params, recompute=False, use_reentrant=use_reentrant
        )
        outputs_recompute, _ = _test_e2e_full_recompute(
            bs, dtype, config, fp8, fp8_model_params, recompute=True, use_reentrant=use_reentrant
        )
    finally:
        if not use_reentrant:
            # Reset bias+GELU fusion flag to avoid contaminating other tests.
            # This runs even when a model run raises, and puts back any value
            # that was set before the test started.
            if previous_flag is None:
                os.environ.pop(fusion_flag, None)
            else:
                os.environ[fusion_flag] = previous_flag

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols["atol"] = 1e-3
    if fp8 or fp8_model_params:
        tols.update(dict(rtol=0.125, atol=0.0675))
    for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )
635
636
637
638
639
640


def _test_e2e_checkpointing_get_model(config, dtype):
    """Build the CUDA TransformerLayer used by the checkpointing test.

    Args:
        config: Object providing hidden_size, num_attention_heads, eps, embed
            and num_layers.
        dtype: Parameter dtype.

    Returns:
        A newly constructed ``TransformerLayer``.
    """
    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    return TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
        device="cuda",
    )


def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path="checkpoint.pt"):
    """Run ``steps`` forward/backward passes, optionally reloading the model halfway.

    Args:
        bs: Batch size of the random input.
        dtype: Parameter and input dtype.
        config: Object providing the model dimensions and seq_len.
        checkpoint: When True, after the first half of the steps the state
            dict is saved to ``path``, the model is rebuilt and reloaded, and
            the accumulated parameter gradients are carried over.
        steps: Total number of passes; ``steps // 2`` run before and after the
            optional reload.
        path: Checkpoint file location. The file is removed at the end if it
            exists.

    Returns:
        List holding the last output, the input gradient, and the gradient of
        every parameter that requires grad.

    Note:
        With ``checkpoint=True`` the module-level ``_cpu_rng_state`` and
        ``_cuda_rng_state`` are overwritten with the mid-run RNG state, so
        later ``reset_rng_states()`` calls restore that state.
    """
    global _cpu_rng_state, _cuda_rng_state

    reset_rng_states()

    te_inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()

    block = _test_e2e_checkpointing_get_model(config, dtype)

    for _ in range(steps // 2):
        te_out = block(
            te_inp_hidden_states,
            None,
        )
        loss = te_out.sum()
        loss.backward()

    if checkpoint:
        # This process is necessary so that we can start afresh with
        # a new model while erasing all internal state to ensure that
        # loading from a checkpoint gives bitwise identical results.
        # Since gradients are being accumulated, it is important to
        # restore them post loading the checkpoint.
        torch.save(block.state_dict(), path)

        param_grads = []
        for p in block.parameters():
            if p.requires_grad:
                param_grads.append(p.grad.clone())

        # Snapshot the RNG state reached so far; building the new model below
        # may consume random numbers, and reset_rng_states() rewinds to here.
        _cpu_rng_state = torch.get_rng_state()
        _cuda_rng_state = torch.cuda.get_rng_state()

        del block
        block = _test_e2e_checkpointing_get_model(config, dtype)
        block.load_state_dict(torch.load(path))
        reset_rng_states()

        for p in block.parameters():
            if p.requires_grad:
                p.grad = param_grads.pop(0)

        assert not param_grads, "Oops!"

    for _ in range(steps // 2):
        te_out = block(
            te_inp_hidden_states,
            None,
        )
        loss = te_out.sum()
        loss.backward()

    torch.cuda.synchronize()

    if os.path.exists(path):
        os.remove(path)

    outputs = [te_out, te_inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_checkpointing(dtype, bs, model):
    """Compare an uninterrupted run against one that reloads from a checkpoint."""
    config = model_configs[model]
    outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
    outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(dict(rtol=2e-2, atol=2e-3))
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )
747
748
749
750
751
752


def _test_e2e_gpt_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of ``block`` with a causal mask on CUDA.

    The RNG state is reset first so every call sees the same random input.

    Args:
        block: Module called as ``block(hidden_states, attention_mask=mask)``.
        bs: Batch size of the random input.
        dtype: Input dtype.
        config: Object providing seq_len and hidden_size.

    Returns:
        List holding the output, the input gradient, and the gradient of every
        parameter that requires grad.
    """
    reset_rng_states()

    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
    inp_attn_mask = get_causal_attn_mask(config.seq_len)

    out = block(inp_hidden_states, attention_mask=inp_attn_mask)
    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
    """Compare the TransformerLayer forward output against the TorchGPT reference."""
    config = model_configs[model]

    te_gpt = TransformerLayer(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        num_attention_heads=config.num_attention_heads,
        layernorm_epsilon=config.eps,
        attention_dropout=0.1,
        hidden_dropout=0.1,
        params_dtype=dtype,
        fuse_qkv_params=True,
        qkv_weight_interleaved=False,
        parallel_attention_mlp=parallel_attention_mlp,
        device="cuda",
    ).eval()

    torch_gpt = (
        TorchGPT(
            config.hidden_size,
            config.eps,
            config.num_attention_heads,
            parallel_attention_mlp=parallel_attention_mlp,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_gpt.ln.weight = Parameter(
            te_gpt.self_attention.layernorm_qkv.layer_norm_weight.clone()
        )
        torch_gpt.ln.bias = Parameter(te_gpt.self_attention.layernorm_qkv.layer_norm_bias.clone())
        torch_gpt.causal_attn.mhsa.in_proj_weight = Parameter(
            te_gpt.self_attention.layernorm_qkv.weight.clone()
        )
        torch_gpt.causal_attn.mhsa.in_proj_bias = Parameter(
            te_gpt.self_attention.layernorm_qkv.bias.clone()
        )
        torch_gpt.causal_attn.mhsa.out_proj.weight = Parameter(
            te_gpt.self_attention.proj.weight.clone()
        )
        torch_gpt.causal_attn.mhsa.out_proj.bias = Parameter(
            te_gpt.self_attention.proj.bias.clone()
        )
        torch_gpt.ln_mlp.ln.weight = Parameter(te_gpt.layernorm_mlp.layer_norm_weight.clone())
        torch_gpt.ln_mlp.ln.bias = Parameter(te_gpt.layernorm_mlp.layer_norm_bias.clone())
        torch_gpt.ln_mlp.fc1.weight = Parameter(te_gpt.layernorm_mlp.fc1_weight.clone())
        torch_gpt.ln_mlp.fc1.bias = Parameter(te_gpt.layernorm_mlp.fc1_bias.clone())
        torch_gpt.ln_mlp.fc2.weight = Parameter(te_gpt.layernorm_mlp.fc2_weight.clone())
        torch_gpt.ln_mlp.fc2.bias = Parameter(te_gpt.layernorm_mlp.fc2_bias.clone())

    te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config)
    torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config)

    # Check output. The tensors are wrapped in one-element lists because
    # assert_allclose compares sequences of tensors; a bare tensor would be
    # iterated along its first dimension instead.
    atol = 5e-3 if dtype == torch.float32 else 5e-2
    assert_allclose([te_outputs[0]], [torch_outputs[0]], atol)


841
def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
    """Run one forward/backward pass of an attention module on CUDA.

    The RNG state is reset first so every call sees the same random input.

    Args:
        block: Attention module under test.
        bs: Batch size of the random input.
        dtype: Input dtype.
        config: Object providing seq_len and hidden_size.
        mask_type: ``"causal"`` builds a causal mask; any other value passes
            ``attention_mask=None``.
        te: When True, ``attn_mask_type=mask_type`` is passed as well.

    Returns:
        List holding the output, the input gradient, and the gradient of every
        parameter that requires grad.
    """
    reset_rng_states()

    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
    inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None

    forward_kwargs = {}
    if te:
        forward_kwargs["attn_mask_type"] = mask_type
    forward_kwargs["attention_mask"] = inp_attn_mask

    out = block(inp_hidden_states, **forward_kwargs)
    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
    """Compare the MultiheadAttention forward output against the TorchMHA reference."""
    config = model_configs[model]

    te_mha = MultiheadAttention(
        config.hidden_size,
        config.num_attention_heads,
        fuse_qkv_params=True,
        params_dtype=dtype,
        qkv_weight_interleaved=False,
        input_layernorm=False,
        device="cuda",
    ).eval()

    torch_mha = (
        TorchMHA(
            config.hidden_size,
            config.num_attention_heads,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_mha.mhsa.in_proj_weight = Parameter(te_mha.qkv.weight.clone())
        torch_mha.mhsa.in_proj_bias = Parameter(te_mha.qkv.bias.clone())
        torch_mha.mhsa.out_proj.weight = Parameter(te_mha.proj.weight.clone())
        torch_mha.mhsa.out_proj.bias = Parameter(te_mha.proj.bias.clone())

    te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type, te=True)
    torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type, te=False)

    # Check output. The tensors are wrapped in one-element lists because
    # assert_allclose compares sequences of tensors; a bare tensor would be
    # iterated along its first dimension instead.
    atol = 5e-3 if dtype == torch.float32 else 5e-2
    assert_allclose([te_outputs[0]], [torch_outputs[0]], atol)


914
915
916
917
def _test_granular_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of ``block`` on a random input.

    The RNG state is reset first, so repeated calls draw the same input.

    Args:
        block: Module called as ``block(hidden_states)``.
        bs: Batch size (second dimension of the generated input).
        dtype: Dtype of the generated input tensor.
        config: Object providing ``seq_len`` and ``hidden_size``.

    Returns:
        ``[out, input_grad, *param_grads]`` where ``param_grads`` are the
        gradients of every parameter of ``block`` with ``requires_grad``.
    """
    reset_rng_states()

    input_shape = (config.seq_len, bs, config.hidden_size)
    hidden = torch.randn(input_shape, dtype=dtype, device="cuda", requires_grad=True)
    hidden.retain_grad()

    out = block(hidden)
    out.sum().backward()

    torch.cuda.synchronize()
    param_grads = [param.grad for param in block.parameters() if param.requires_grad]
    return [out, hidden.grad, *param_grads]


937
938
939
def _test_dpa_accuracy(block, bs, dtype, config):
    """Run one forward/backward pass of a dot-product-attention block.

    Random query, key and value tensors of shape
    ``(seq_len, bs, num_attention_heads, embed)`` are generated after the RNG
    state is reset, and a strictly-upper-triangular boolean mask is passed
    as ``attention_mask``.

    Returns:
        ``[out, query_grad, key_grad, value_grad]``.
    """
    reset_rng_states()

    # True above the main diagonal, False on and below it.
    all_true = torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda")
    mask = torch.triu(all_true, diagonal=1)

    qkv_shape = (config.seq_len, bs, config.num_attention_heads, config.embed)
    qkv = []
    for _ in range(3):
        qkv.append(torch.randn(qkv_shape, dtype=dtype, device="cuda", requires_grad=True))
    query, key, value = qkv

    for tensor in (query, key, value):
        tensor.retain_grad()

    out = block(query, key, value, attention_mask=mask)
    out.sum().backward()

    torch.cuda.synchronize()

    return [out, query.grad, key.grad, value.grad]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_dpa_accuracy(dtype, bs, model):
    """Check ``DotProductAttention`` forward output against the torch reference."""
    cfg = model_configs[model]

    te_dpa = DotProductAttention(
        cfg.num_attention_heads,
        cfg.embed,
        attention_dropout=0.0,  # disable dropout, FU uses rng differently
    )
    te_dpa = te_dpa.to(dtype=dtype).cuda()

    torch_dpa = TorchDotProductAttention(
        cfg.embed,
        0.0,  # dropout
    )
    torch_dpa = torch_dpa.to(dtype=dtype).cuda()

    te_outputs = _test_dpa_accuracy(te_dpa, bs, dtype, cfg)
    torch_outputs = _test_dpa_accuracy(torch_dpa, bs, dtype, cfg)

    # Check output: tighter tolerance for float32 than for the other dtypes.
    tolerance = 5e-3 if dtype == torch.float32 else 5e-2
    assert_allclose(te_outputs[0], torch_outputs[0], tolerance)


1001
1002
1003
1004
1005
1006
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_linear_accuracy(dtype, bs, model):
    """Check ``Linear`` forward output against ``torch.nn.Linear``.

    The reference layer receives clones of the TE weight and bias; only the
    forward output (index 0 of the helper's result) is compared.
    """
    cfg = model_configs[model]
    in_features = cfg.hidden_size
    out_features = 4 * cfg.hidden_size

    te_linear = Linear(
        in_features,
        out_features,
        bias=True,
        params_dtype=dtype,
        device="cuda",
    ).eval()

    torch_linear = torch.nn.Linear(
        in_features,
        out_features,
        bias=True,
        device="cuda",
        dtype=dtype,
    ).eval()

    # Copy the TE parameters into the reference layer.
    with torch.no_grad():
        torch_linear.weight = Parameter(te_linear.weight.clone())
        torch_linear.bias = Parameter(te_linear.bias.clone())

    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, cfg)
    torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, cfg)

    # Check output: tighter tolerance for float32 than for the other dtypes.
    tolerance = 5e-3 if dtype == torch.float32 else 5e-2
    assert_allclose(te_outputs[0], torch_outputs[0], tolerance)

1037

1038
1039
1040
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
    """Check ``RMSNorm`` forward output against ``TorchRMSNorm``.

    The reference module receives a clone of the TE weight; only the forward
    output (index 0 of the helper's result) is compared.
    """
    cfg = model_configs[model]

    te_rmsnorm = RMSNorm(
        cfg.hidden_size,
        eps=eps,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_rmsnorm = TorchRMSNorm(
        cfg.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma
    )
    torch_rmsnorm = torch_rmsnorm.to(dtype=dtype).cuda().eval()

    # Copy the TE weight into the reference module.
    with torch.no_grad():
        torch_rmsnorm.weight = Parameter(te_rmsnorm.weight.clone())

    te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, cfg)
    torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, cfg)

    # Check output with a per-dtype tolerance.
    atol_by_dtype = {
        torch.float32: 1e-7,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }
    assert_allclose(te_outputs[0], torch_outputs[0], atol_by_dtype[dtype])

1076

1077
1078
1079
1080
1081
1082
1083
1084
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
    """Check ``LayerNorm`` forward output against ``TorchLayerNorm``.

    The reference module receives clones of the TE weight and bias; only the
    forward output (index 0 of the helper's result) is compared.
    """
    cfg = model_configs[model]

    te_layernorm = LayerNorm(
        cfg.hidden_size,
        eps=eps,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_layernorm = TorchLayerNorm(
        cfg.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma
    )
    torch_layernorm = torch_layernorm.to(dtype=dtype).cuda().eval()

    # Copy the TE parameters into the reference module.
    with torch.no_grad():
        torch_layernorm.weight = Parameter(te_layernorm.weight.clone())
        torch_layernorm.bias = Parameter(te_layernorm.bias.clone())

    te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, cfg)
    torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, cfg)

    # Check output with a per-dtype tolerance.
    atol_by_dtype = {
        torch.float32: 1e-7,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }
    assert_allclose(te_outputs[0], torch_outputs[0], atol_by_dtype[dtype])
1115

1116

1117
1118
1119
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
    """Check ``LayerNormLinear`` forward output against ``TorchLayerNormLinear``.

    The reference module receives clones of the TE norm and linear
    parameters; the norm bias is copied only when ``normalization`` is not
    ``"RMSNorm"``. Only the forward output is compared.
    """
    cfg = model_configs[model]
    in_features = cfg.hidden_size
    out_features = 4 * cfg.hidden_size

    te_ln_linear = LayerNormLinear(
        in_features,
        out_features,
        cfg.eps,
        bias=True,
        normalization=normalization,
        params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
        device="cuda",
    ).eval()

    torch_ln_linear = TorchLayerNormLinear(
        in_features,
        out_features,
        cfg.eps,
        bias=True,
        normalization=normalization,
        zero_centered_gamma=zero_centered_gamma,
    )
    torch_ln_linear = torch_ln_linear.to(dtype=dtype).cuda().eval()

    def cloned(tensor):
        return Parameter(tensor.clone())

    # Copy the TE parameters into the reference module.
    with torch.no_grad():
        torch_ln_linear.layernorm.weight = cloned(te_ln_linear.layer_norm_weight)
        if normalization != "RMSNorm":
            torch_ln_linear.layernorm.bias = cloned(te_ln_linear.layer_norm_bias)
        torch_ln_linear.linear.weight = cloned(te_ln_linear.weight)
        torch_ln_linear.linear.bias = cloned(te_ln_linear.bias)

    te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, cfg)
    torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, cfg)

    # Check output with a per-dtype tolerance.
    atol_by_dtype = {
        torch.float32: 2.5e-4,
        torch.half: 2e-3,
        torch.bfloat16: 2e-2,
    }
    assert_allclose(te_outputs[0], torch_outputs[0], atol_by_dtype[dtype])
1168
1169


1170
1171
1172
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
    """Check ``LayerNormMLP`` forward output against ``TorchLayerNormMLP``.

    The reference module receives clones of the TE norm, fc1 and fc2
    parameters; the norm bias is copied only when ``normalization`` is not
    ``"RMSNorm"``. Only the forward output is compared.
    """
    cfg = model_configs[model]
    hidden = cfg.hidden_size
    ffn_hidden = 4 * cfg.hidden_size

    te_ln_mlp = LayerNormMLP(
        hidden,
        ffn_hidden,
        activation=activation,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
    ).eval()

    torch_ln_mlp = TorchLayerNormMLP(
        hidden,
        ffn_hidden,
        activation=activation,
        normalization=normalization,
    )
    torch_ln_mlp = torch_ln_mlp.to(dtype=dtype).cuda().eval()

    def cloned(tensor):
        return Parameter(tensor.clone())

    # Copy the TE parameters into the reference module.
    with torch.no_grad():
        torch_ln_mlp.ln.weight = cloned(te_ln_mlp.layer_norm_weight)
        if normalization != "RMSNorm":
            torch_ln_mlp.ln.bias = cloned(te_ln_mlp.layer_norm_bias)
        torch_ln_mlp.fc1.weight = cloned(te_ln_mlp.fc1_weight)
        torch_ln_mlp.fc1.bias = cloned(te_ln_mlp.fc1_bias)
        torch_ln_mlp.fc2.weight = cloned(te_ln_mlp.fc2_weight)
        torch_ln_mlp.fc2.bias = cloned(te_ln_mlp.fc2_bias)

    te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, cfg)
    torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, cfg)

    # Check output: tighter tolerance for float32 than for the other dtypes.
    tolerance = 5e-3 if dtype == torch.float32 else 5e-2
    assert_allclose(te_outputs[0], torch_outputs[0], tolerance)


1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):
    """Run one forward/backward pass through a grouped or sequential linear.

    The input is split along its first dimension into ``num_gemms`` chunks
    whose sizes are multiples of 16 (one chunk is deliberately empty). A
    ``GroupedLinear`` receives the whole input plus the split sizes scaled by
    ``bs``; any other ``block`` is indexed per chunk and the results are
    concatenated.

    Args:
        block: ``GroupedLinear`` or an indexable collection of modules.
        num_gemms: Number of chunks / GEMMs.
        bs: Batch size (second dimension of the generated input).
        dtype: Dtype of the generated input tensor.
        config: Object providing ``seq_len`` and ``hidden_size``.
        fp8: Enables ``fp8_autocast`` and resets the FP8 global state first.

    Returns:
        ``[out, input_grad, *param_grads]`` where ``param_grads`` are the
        gradients of every parameter of ``block`` with ``requires_grad``.
    """
    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    input_shape = (config.seq_len, bs, config.hidden_size)
    hidden = torch.randn(input_shape, dtype=dtype, device="cuda", requires_grad=True)
    hidden.retain_grad()

    # Draw sorted split boundaries in units of 16 rows.
    m = config.seq_len // 16
    boundaries = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
    # Repeating the last boundary yields one zero-length chunk.
    boundaries.append(boundaries[-1])
    upper = torch.tensor(boundaries + [m])
    lower = torch.tensor([0] + boundaries)
    m_splits = (upper - lower) * 16
    assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms

    with fp8_autocast(enabled=fp8):
        if isinstance(block, GroupedLinear):
            m_splits = m_splits * bs
            out = block(hidden, m_splits.tolist())
        else:
            chunks = torch.split(hidden, m_splits.tolist())
            partial_outputs = []
            for idx, chunk in enumerate(chunks):
                partial_outputs.append(block[idx](chunk))
            out = torch.cat(partial_outputs)
    out.sum().backward()

    torch.cuda.synchronize()
    param_grads = [param.grad for param in block.parameters() if param.requires_grad]
    return [out, hidden.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_grouped_linear_accuracy(
    dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None
):
    """Check that ``GroupedLinear`` matches ``num_gemms`` separate ``Linear`` layers.

    Each ``Linear`` receives clones of the corresponding ``weight{i}`` /
    ``bias{i}`` of the grouped module. Every returned tensor (output, input
    gradient, parameter gradients) must match with zero tolerance.
    """
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    cfg = model_configs[model]
    if fp8 and cfg.seq_len % 16 != 0:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    in_features = cfg.hidden_size
    out_features = 4 * cfg.hidden_size

    with fp8_model_init(enabled=fp8 and fp8_model_params):
        grouped_linear = GroupedLinear(
            num_gemms,
            in_features,
            out_features,
            bias=True,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
        ).eval()

        reference_layers = []
        for _ in range(num_gemms):
            layer = Linear(
                in_features,
                out_features,
                bias=True,
                params_dtype=dtype,
                parallel_mode=parallel_mode,
                device="cuda",
            ).eval()
            reference_layers.append(layer)
        sequential_linear = torch.nn.ModuleList(reference_layers)

    # Copy each grouped weight/bias into the matching reference layer.
    with torch.no_grad():
        for idx in range(num_gemms):
            grouped_weight = getattr(grouped_linear, f"weight{idx}")
            grouped_bias = getattr(grouped_linear, f"bias{idx}")
            sequential_linear[idx].weight = Parameter(grouped_weight.clone())
            sequential_linear[idx].bias = Parameter(grouped_bias.clone())

    outputs = _test_grouped_linear_accuracy(grouped_linear, num_gemms, bs, dtype, cfg, fp8)
    outputs_ref = _test_grouped_linear_accuracy(
        sequential_linear, num_gemms, bs, dtype, cfg, fp8
    )

    # Should be a bit-wise match
    for actual, expected in zip(outputs, outputs_ref):
        torch.testing.assert_close(actual, expected, rtol=0, atol=0)


1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
@pytest.mark.parametrize("parallel_mode", ["column", "row"])
def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
    """Run one fixed grouped-linear case per parallel mode.

    Split from the main parametrized test to reduce CI time.
    """
    first_model = list(model_configs.keys())[0]
    fixed_case = {
        "dtype": torch.float32,
        "num_gemms": 6,
        "bs": 2,
        "model": first_model,
        "fp8": True,
        "fp8_model_params": True,
    }
    test_grouped_linear_accuracy(parallel_mode=parallel_mode, **fixed_case)


1331
1332
1333
1334
1335
1336
1337
def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
    """Run SGD training steps on ``block``, optionally replayed from a CUDA graph.

    Three warmup steps run on a side stream. When ``graph`` is true, one
    training step is then captured into a ``torch.cuda.CUDAGraph``. Fresh
    data is copied into the static input/target tensors and one more step is
    executed, by replaying the graph or by calling the step eagerly.

    Args:
        block: Module to train; called as ``block(static_input)``.
        bs: Batch size (second dimension of the generated tensors).
        dtype: Dtype of the generated tensors.
        config: Object providing ``seq_len`` and ``hidden_size``.
        graph: Whether to capture and replay the step as a CUDA graph.

    Returns:
        ``(output, grads)``: a clone of the last step's output, and a list
        holding the static input's gradient followed by the gradient of every
        parameter of ``block`` with ``requires_grad``.
    """
    reset_rng_states()

    # Loss function and optimizer.
    criterion = torch.nn.MSELoss()
    sgd = torch.optim.SGD(block.parameters(), lr=0.1)

    # Static tensors: the captured graph reads from these buffers.
    io_shape = (config.seq_len, bs, config.hidden_size)
    static_input = torch.randn(*io_shape, device="cuda", dtype=dtype, requires_grad=True)
    static_target = torch.randn(*io_shape, device="cuda", dtype=dtype)

    # Data copied into the static buffers for the final step.
    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    def train_step():
        # Gradients are zeroed in place rather than set to None.
        sgd.zero_grad(set_to_none=False)
        prediction = block(static_input)
        step_loss = criterion(prediction, static_target)
        step_loss.backward()
        sgd.step()
        return prediction

    # Warmup on a side stream before any capture.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _ in range(3):
            train_step()
    torch.cuda.current_stream().wait_stream(side_stream)

    # Capture one training step if requested.
    cuda_graph = None
    static_output = None
    if graph:
        cuda_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(cuda_graph):
            static_output = train_step()

    # Load the new data, then run exactly one more step.
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    if graph:
        cuda_graph.replay()
    else:
        static_output = train_step()

    grads = [static_input.grad]
    grads.extend(param.grad for param in block.parameters() if param.requires_grad)

    with torch.no_grad():
        output = static_output.clone()
    return output, grads


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_cuda_graph(dtype, bs, model):
    """Check that graphed and eager training of a ``TransformerLayer`` agree.

    A layer and its deep copy are trained by the same helper, one eagerly and
    one through a CUDA graph; outputs, parameters and gradients are compared.
    """
    cfg = model_configs[model]

    sigma = 0.023
    layer_kwargs = {
        "layernorm_epsilon": cfg.eps,
        "init_method": init_method_normal(sigma),
        "output_layer_init_method": scaled_init_method_normal(sigma, cfg.num_layers),
        "hidden_dropout": 0.1,
        "attention_dropout": 0.1,
        "kv_channels": cfg.embed,
        "params_dtype": dtype,
        "apply_residual_connection_post_layernorm": False,
        "output_layernorm": False,
        "device": "cuda",
    }
    block = TransformerLayer(
        cfg.hidden_size,
        4 * cfg.hidden_size,
        cfg.num_attention_heads,
        **layer_kwargs,
    )
    # Identical starting weights for the graphed run.
    graphed_block = copy.deepcopy(block)

    out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, cfg, False)
    graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, cfg, True)
    params = list(block.parameters())
    graphed_params = list(graphed_block.parameters())

    # Check that results match
    tolerance = 1e-3
    assert_allclose(out, graphed_out, tolerance)
    assert_allclose(params, graphed_params, tolerance)
    assert_allclose(grads, graphed_grads, tolerance)
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437


def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):
    """Build a ``TransformerLayer`` and run one FP8 forward/backward pass.

    The RNG state and the FP8 global state are reset first. The layer is
    constructed inside ``fp8_model_init(enabled=fp8_model_params)`` and is
    always executed under ``fp8_autocast(enabled=True)`` with a causal mask.

    Args:
        bs: Batch size (second dimension of the generated input).
        dtype: Dtype of the parameters and of the generated input.
        config: Model configuration providing sizes, ``eps`` and ``num_layers``.
        fp8_model_params: Passed as ``enabled`` to ``fp8_model_init``.

    Returns:
        ``[out, input_grad, *param_grads]`` where ``param_grads`` are the
        gradients of every parameter of the layer with ``requires_grad``.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    layer_kwargs = {
        "layernorm_epsilon": config.eps,
        "init_method": init_method_normal(sigma),
        "output_layer_init_method": scaled_init_method_normal(sigma, config.num_layers),
        "hidden_dropout": 0.1,
        "attention_dropout": 0.1,
        "kv_channels": config.embed,
        "apply_residual_connection_post_layernorm": False,
        "output_layernorm": False,
        "params_dtype": dtype,
        "fuse_qkv_params": True,
        "device": "cuda",
    }

    with fp8_model_init(enabled=fp8_model_params):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            **layer_kwargs,
        )

    input_shape = (config.seq_len, bs, config.hidden_size)
    hidden = torch.randn(input_shape, dtype=dtype, device="cuda", requires_grad=True)
    hidden.retain_grad()
    causal_mask = get_causal_attn_mask(config.seq_len)

    with fp8_autocast(enabled=True):
        te_out = block(hidden, attention_mask=causal_mask)
    te_out.sum().backward()
    torch.cuda.synchronize()

    param_grads = [param.grad for param in block.parameters() if param.requires_grad]
    return [te_out, hidden.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_fp8_parameters(dtype, bs, model):
    """Compare an FP8 run with and without FP8 model parameters.

    The helper is invoked twice, once with ``fp8_model_params=False`` and
    once with ``True``; each pair of returned tensors must agree within the
    tolerances below. Skipped when FP8 is not available.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    outputs = _test_gpt_fp8_parameters(bs, dtype, config, False)
    outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True)

    # Check that results match. The tolerances are defined once and reused,
    # instead of being repeated as literals in the comparison call.
    tols = dict(rtol=0.125, atol=0.0675)
    for i, (ref, test) in enumerate(zip(outputs, outputs_fp8_params)):
        torch.testing.assert_close(
            test,
            ref,
            msg=f"Mismatch in tensor {i}",
            **tols,
        )

1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_transformer_layer_hidden_states_format(dtype, bs, model):
    """Check that ``TransformerLayer`` output is independent of the input layout.

    Three layers are built from the same seed with ``attn_input_format`` set
    to ``"sbhd"``, ``"bshd"`` and ``"thd"``. Their parameters must be
    identical, and the same data fed in each layout must give matching
    outputs. The THD comparison only runs for non-float32 dtypes when
    ``sm_80plus`` is true.
    """
    cfg = model_configs[model]

    sigma = 0.023
    # Constructor arguments common to all three layers. Dropout is 0 so the
    # forward passes are identical across layers.
    layer_args = (cfg.hidden_size, 4 * cfg.hidden_size, cfg.num_attention_heads)
    layer_kwargs = {
        "layernorm_epsilon": cfg.eps,
        "init_method": init_method_normal(sigma),
        "output_layer_init_method": scaled_init_method_normal(sigma, cfg.num_layers),
        "hidden_dropout": 0,
        "attention_dropout": 0,
        "kv_channels": cfg.embed,
        "params_dtype": dtype,
        "apply_residual_connection_post_layernorm": False,
        "output_layernorm": False,
        "device": "cuda",
    }

    # Re-seed before every construction so the weights are identical.
    torch.manual_seed(0)
    block_sbhd = TransformerLayer(*layer_args, attn_input_format="sbhd", **layer_kwargs)

    torch.manual_seed(0)
    block_bshd = TransformerLayer(*layer_args, attn_input_format="bshd", **layer_kwargs)

    torch.manual_seed(0)
    block_thd = TransformerLayer(
        *layer_args,
        attn_input_format="thd",
        self_attn_mask_type="padding_causal",
        **layer_kwargs,
    )

    named_params = zip(
        block_bshd.named_parameters(),
        block_sbhd.named_parameters(),
        block_thd.named_parameters(),
    )
    for (n1, p1), (n2, p2), (n3, p3) in named_params:
        same_everywhere = torch.eq(p1, p2) & torch.eq(p1, p3)
        assert torch.all(same_everywhere), f"{n1}, {n2} and {n3} not identical"

    x_sbhd = torch.randn(
        (cfg.seq_len, bs, cfg.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    # The same data re-laid-out for the other two formats.
    x_bshd = x_sbhd.transpose(0, 1).contiguous()
    x_thd = x_bshd.reshape(bs * cfg.seq_len, cfg.hidden_size).contiguous()
    # Cumulative sequence lengths: every sequence has length ``seq_len``.
    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * cfg.seq_len

    # Re-seed before each forward as well, in case a module uses the RNG.
    torch.manual_seed(0)
    y_sbhd = block_sbhd(x_sbhd)

    torch.manual_seed(0)
    y_bshd = block_bshd(x_bshd)

    # Check that results match
    expected_bshd = y_sbhd.transpose(0, 1).contiguous()
    torch.testing.assert_close(y_bshd, expected_bshd)

    # THD is not supported in float32 and on GPUs older than Ampere, skip the test here
    if dtype == torch.float32 or not sm_80plus:
        return

    torch.manual_seed(0)
    y_thd = block_thd(
        x_thd,
        cu_seqlens_q=x_thd_cumsum,
        cu_seqlens_kv=x_thd_cumsum,
        max_seqlen_q=cfg.seq_len,
        max_seqlen_kv=cfg.seq_len,
    )

    torch.testing.assert_close(
        y_bshd,
        y_thd.reshape(bs, cfg.seq_len, cfg.hidden_size).contiguous(),
    )

1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model_key", model_configs_inference.keys())
@pytest.mark.parametrize("use_RoPE", all_boolean)
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
    """Check that token-by-token inference with a KV cache matches a full pass.

    A ``TransformerLayer`` or ``MultiheadAttention`` is run once over the
    whole sequence, then one position at a time with ``InferenceParams``;
    the two outputs must agree within a per-dtype tolerance.

    The attention backend is chosen through the ``NVTE_FLASH_ATTN`` and
    ``NVTE_FUSED_ATTN`` environment variables. Their previous values are
    restored when the test finishes, so the selection does not leak into
    tests that run afterwards.
    """
    backend_env_vars = ("NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN")
    saved_env = {name: os.environ.get(name) for name in backend_env_vars}

    try:
        # Turn both backends off, then re-enable only the requested one.
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "0"

        if backend == "FlashAttention":
            os.environ["NVTE_FLASH_ATTN"] = "1"
        elif backend == "FusedAttention":
            os.environ["NVTE_FUSED_ATTN"] = "1"

        config = model_configs_inference[model_key]

        S = config.seq_len
        B = bs
        H = config.num_attention_heads
        D = config.hidden_size
        head_size = config.embed
        layer_number = 1

        # Limits the max size of KV-cache
        B_max = B
        S_max = S + 2

        if module == "TransformerLayer":
            model = TransformerLayer(
                hidden_size=D,
                ffn_hidden_size=4 * D,
                num_attention_heads=H,
                attn_input_format=input_format,
                self_attn_mask_type="causal_bottom_right",
                enc_dec_attn_mask_type="causal_bottom_right",
                layer_number=layer_number,
                attention_dropout=0.0,
                params_dtype=dtype,
                device="cuda",
            ).eval()
        else:
            model = (
                MultiheadAttention(
                    hidden_size=D,
                    num_attention_heads=H,
                    qkv_format=input_format,
                    layer_number=layer_number,
                    attention_dropout=0.0,
                    attn_mask_type="causal_bottom_right",
                    params_dtype=dtype,
                )
                .cuda()
                .eval()
            )

        inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max)
        rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
        rotary_pos_emb = rotary_freqs if use_RoPE else None

        # Named ``full_input`` to avoid shadowing the ``input`` builtin.
        full_input = torch.randn((S, B, D), dtype=dtype, device="cuda")
        if input_format == "bshd":
            full_input = full_input.transpose(0, 1).contiguous()

        incremental_output = torch.zeros_like(full_input)

        # Generate output for the entire sequence
        full_output = model(hidden_states=full_input, rotary_pos_emb=rotary_pos_emb)

        # Incrementally generate outputs using KV-cache
        for i in range(S):
            if input_format == "sbhd":
                incremental_input = full_input[i].view(1, B, D)
            else:
                incremental_input = full_input[:, i, :].view(B, 1, D)

            line_output = model(
                hidden_states=incremental_input,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
            )

            # Advance the offset by the one position just processed.
            inference_params.sequence_len_offset += 1

            if input_format == "sbhd":
                incremental_output[i] = line_output.view(B, D)
            else:
                incremental_output[:, i, :] = line_output.view(B, D)

        if module == "TransformerLayer":
            atol = {
                torch.float32: 5e-3,
                torch.half: 5e-3,
                torch.bfloat16: 5e-2,
            }
        else:
            atol = {
                torch.float32: 1e-3,
                torch.half: 1e-3,
                torch.bfloat16: 1e-2,
            }

        # Check if the fully generated output matches the one generated incrementally
        assert_allclose(full_output, incremental_output, atol[dtype])
    finally:
        # Put the backend switches back exactly as they were, removing any
        # variable that did not exist before the test.
        for name, previous in saved_env.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880


@pytest.mark.parametrize(
    "shape",
    [
        (1, 127, 128, 512),
        (8, 15, 128, 512),
        (8, 1027, 128, 512),
        (16, 10027, 128, 512),
    ],
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm(shape, dtype, layout, accumulate):
    """Compare `grouped_gemm` against one `gemm` call per group.

    `shape` unpacks to `(z, m, k, n)`: `z` is the number of groups and `m` is the
    total row count, which is divided among the groups at random cut points.
    Every group is first computed with an individual `gemm` call into a cloned
    output buffer; then all groups are computed with a single `grouped_gemm`
    call. The two sets of outputs are required to be bit-wise identical.
    """
    torch.manual_seed(0)
    z, m, k, n = shape

    # Draw z - 1 sorted cut points in [0, m); the gaps between consecutive
    # cut points (bracketed by 0 and m) are the per-group row counts.
    # Repeated cut points yield groups with zero rows.
    cuts = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist()
    m_splits = [upper - lower for lower, upper in zip([0] + cuts, cuts + [m])]
    assert sum(m_splits) == m and len(m_splits) == z

    def _rand(rows, cols):
        # Random CUDA matrix in the dtype under test.
        return torch.randn(rows, cols, dtype=dtype, device="cuda")

    if layout in ("TN", "NN"):
        # TN: B is the input and out is the output.
        # NN: B is grad_output and out is dgrad.
        b_cols, out_cols = (k, n) if layout == "TN" else (n, k)
        A = [_rand(n, k) for _ in range(z)]  # weight
        B = torch.split(_rand(m, b_cols), m_splits)
        out = torch.split(_rand(m, out_cols), m_splits)
    else:  # layout == "NT"
        A = torch.split(_rand(m, k), m_splits)  # input
        B = torch.split(_rand(m, n), m_splits)  # grad_output
        out = [_rand(n, k) for _ in range(z)]  # wgrad
    grad = layout != "TN"

    # Reference: one plain GEMM per group, written into copies of the outputs
    # so that accumulate=True starts from the same initial values.
    out_ref = [o.clone() for o in out]
    for a_i, b_i, ref_i in zip(A, B, out_ref):
        gemm(
            a_i,
            b_i,
            dtype,
            get_workspace(),
            grad=grad,
            accumulate=accumulate,
            layout=layout,
            out=ref_i,
        )

    # Under test: all groups in a single grouped call, written into `out`.
    grouped_gemm(
        A,
        B,
        out,
        dtype,
        get_multi_stream_cublas_workspace(),
        grad=grad,
        accumulate=accumulate,
        layout=layout,
    )

    # should be bit-wise match
    for actual, expected in zip(out, out_ref):
        torch.testing.assert_close(actual, expected, rtol=0, atol=0)


@pytest.mark.parametrize(
    "shape",
    [
        (1, 128, 128, 512),
        (8, 1024, 128, 512),
        (16, 4096, 128, 512),
    ],
)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("accumulate", [False, True])
def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
    """Compare `fp8_grouped_gemm` against one `fp8_gemm` call per group.

    `shape` unpacks to `(z, m, k, n)`; the `m` rows are split into chunks of
    `m // z` rows. Weights are always cast to E4M3, while the inputs are cast
    to `fp8_dtype`. Outputs are bfloat16 and the grouped and per-group results
    are required to be bit-wise identical.

    Skipped when FP8 is not available.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)

    z, m, k, n = shape
    rows_per_group = m // z

    dtype = torch.bfloat16
    weight_fp8_dtype = tex.DType.kFloat8E4M3

    A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
    B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), rows_per_group)  # input
    out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), rows_per_group)  # output
    out_ref = [o.clone() for o in out]

    # fp8 should be robust enough to this fake scale
    scale = 1 + torch.rand(z * 3, dtype=torch.float32, device="cuda")
    scale_inv = 1 / scale
    amax = torch.zeros(1024, z * 3, dtype=torch.float32, device="cuda")

    def _to_fp8(tensor, meta_index, out_dtype):
        # Cast to FP8 using the shared scale/amax/scale_inv tensors;
        # `meta_index` is the fp8 meta tensor index for this operand.
        return torch.ops.tex_ts.cast_to_fp8_ts(
            tensor,
            scale,
            amax,
            scale_inv,
            meta_index,
            out_dtype,
        )

    # Weights use meta indices [0, z); inputs use [z, 2 * z).
    A_fp8 = [_to_fp8(weight, i, weight_fp8_dtype) for i, weight in enumerate(A)]
    B_fp8 = [_to_fp8(B[i], z + i, fp8_dtype) for i in range(z)]

    # Under test: all groups in a single grouped call, written into `out`.
    fp8_grouped_gemm(
        A_fp8,
        scale_inv,
        0,  # A_offset
        weight_fp8_dtype,
        B_fp8,
        scale_inv,
        z,  # B_offset
        fp8_dtype,
        out,
        dtype,
        get_multi_stream_cublas_workspace(),
        accumulate=accumulate,
    )

    # baseline: one FP8 GEMM per group, written into the cloned outputs
    for i, (a_i, b_i, ref_i) in enumerate(zip(A_fp8, B_fp8, out_ref)):
        fp8_gemm(
            a_i,
            scale_inv,
            i,
            weight_fp8_dtype,
            b_i,
            scale_inv,
            z + i,
            fp8_dtype,
            dtype,
            get_workspace(),
            out=ref_i,
            accumulate=accumulate,
        )

    # should be bit-wise match
    for actual, expected in zip(out, out_ref):
        torch.testing.assert_close(actual, expected, rtol=0, atol=0)
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929


def test_noncontiguous():
    """Check that modules agree on a non-contiguous input and its contiguous copy.

    For LayerNorm, RMSNorm and Linear in turn, two instances with identical
    parameter values are created. One is run on a transposed (non-contiguous)
    tensor, the other on a contiguous copy of it. The forward outputs and the
    collected gradients are then compared with a tolerance of 1e-7.
    """

    def _create2modules(m, params):
        # Build two instances of `m` and copy the first one's parameter
        # values into the second so both start from the same state.
        first, second = m(*params), m(*params)
        for src, dst in zip(first.parameters(), second.parameters()):
            dst.data = src.data.clone()

        return first, second

    def _run_module(m, inp):
        # Forward + backward; returns [output, (input grad), *parameter grads].
        result = m(inp)
        result.sum().backward()
        collected = [result]
        # NOTE(review): the inputs passed below come from `.T` / `.contiguous()`,
        # so they look like non-leaf tensors; `inp.grad` is then presumably None
        # and input gradients are not part of the comparison -- confirm.
        if inp.grad is not None:
            collected.append(inp.grad)

        collected.extend(p.grad for p in m.parameters() if p.requires_grad)
        return collected

    a = torch.randn((128, 256), device="cuda", requires_grad=True)
    a = a.T
    assert not a.is_contiguous(), "The test is supposed to test noncontiguous input."

    b = a.contiguous()

    # Same order as before: LayerNorm, RMSNorm, then GEMM (Linear).
    cases = (
        (LayerNorm, [128]),
        (RMSNorm, [128]),
        (Linear, [128, 128]),
    )
    for module_cls, ctor_args in cases:
        mod_t, mod_c = _create2modules(module_cls, ctor_args)
        outT = _run_module(mod_t, a)
        out = _run_module(mod_c, b)

        assert_allclose(out, outT, 1e-7)