test_numerics.py 37.9 KB
Newer Older
1
2
3
4
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
import math
6
7
8
import os
from typing import List, Optional
import pytest
9
import copy
10
11
12
13
14

import torch
import torch.nn as nn
from torch.nn import Parameter

15
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init
16
17
18
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
19
20
21
    attention_mask_func,
)
from transformer_engine.pytorch import (
22
23
    DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
    MultiheadAttention, RMSNorm, TransformerLayer
24
25
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
26
27
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker

28

29
30
31
32
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Seed the CPU and CUDA generators with a fixed value.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
# reset_rng_states() restores these two snapshots.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()


class ModelConfig:
    """Plain container for the model dimensions used by the tests in this file.

    The constructor stores every argument unchanged under an attribute of the
    same name. In this file `embed` is the value handed to the layers as
    `kv_channels`, and `seq_len` is the leading dimension of the random inputs.
    """

    def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
        # Attributes are assigned left to right, in the order of the signature.
        self.hidden_size, self.eps, self.num_attention_heads = (
            hidden_size,
            eps,
            num_attention_heads,
        )
        self.embed, self.num_layers, self.seq_len = embed, num_layers, seq_len


# Model configurations exercised by the tests, keyed by a short label.
model_configs = {
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}

# Parameter dtypes to test; bfloat16 is added only when the device reports support.
param_types = [torch.float32, torch.float16]
if torch.cuda.is_bf16_supported():
    param_types.append(torch.bfloat16)

# Batch sizes used by the parametrized tests.
batch_sizes = [1, 2]

# Both values of a boolean flag, for parametrizing on/off options.
all_boolean = [True, False]

63
# Activation names; each one is a key of `_supported_act`.
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
64

65
66
# Normalization names accepted by the Torch reference modules in this file.
all_normalizations = ["LayerNorm", "RMSNorm"]

67
68
69
# Attention mask types that test_mha_accuracy is parametrized over.
mask_types = ["causal", "no_mask"]


70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def get_causal_attn_mask(sq: int) -> torch.Tensor:
    """Return an (sq, sq) boolean CUDA tensor that is True strictly above the diagonal."""
    all_ones = torch.ones(sq, sq, device="cuda")
    strictly_upper = torch.triu(all_ones, diagonal=1)
    return strictly_upper.bool()


def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> None:
    """Assert that two lists hold the same number of tensors and match exactly.

    Tensors are compared pairwise with `torch.equal`, so size and every
    element must agree.

    Raises:
        AssertionError: if the lengths differ or any pair is not equal; the
            message names the index of the first mismatching pair.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    for idx, (t1, t2) in enumerate(zip(l1, l2)):
        assert torch.equal(t1, t2), f"Output mismatch at index {idx}."


def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) -> None:
    """Assert that two sequences of tensors are pairwise close.

    Each pair is checked with `torch.allclose(t1, t2, atol=atol)`; the relative
    tolerance is left at that function's default. Some callers in this file
    pass two plain tensors instead of lists; those are then compared slice by
    slice along their first dimension.

    Raises:
        AssertionError: if the lengths differ, or if a pair is not close. In
            the latter case the message reports the flattened index of the
            largest absolute difference and both values at that position.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    for t1, t2 in zip(l1, l2):
        if torch.allclose(t1, t2, atol=atol):
            continue
        diff = torch.abs(t1 - t2).flatten()
        m = torch.argmax(diff)
        msg = (
            "Outputs not close enough. "
            f"Location of the maximum difference: {m.item()} "
            f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
            f"(diff {diff[m].item()})."
        )
        raise AssertionError(msg)
95
96
97


def reset_rng_states() -> None:
    """Restore the CPU and CUDA RNG states recorded in the module-level snapshots."""
    restorers = (
        (torch.set_rng_state, _cpu_rng_state),
        (_set_cuda_rng_state, _cuda_rng_state),
    )
    # CPU first, then CUDA.
    for restore, saved_state in restorers:
        restore(saved_state)


103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
class TorchScaledMaskedSoftmax(nn.Module):
    """Reference softmax over the last dimension with optional scaling and masking.

    The computation runs in float32 and the result is cast back to the dtype
    of the input.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        """Return softmax(scale * inp) along the last dim, masked when `mask` is given.

        `scale` is skipped when it is None. When `mask` is not None the scaled
        scores are passed through `attention_mask_func` before the softmax.
        """
        original_dtype = inp.dtype
        scores = inp.float()

        if scale is not None:
            scores = scores * scale
        if mask is not None:
            scores = attention_mask_func(scores, mask)

        return torch.softmax(scores, dim=-1).to(original_dtype)


class TorchDotProductAttention(torch.nn.Module):
    """Reference scaled dot-product attention built from plain PyTorch ops.

    Args:
        kv_channels: per-head channel count; scores are divided by its square root.
        attention_dropout: dropout probability applied to the attention probabilities.
    """

    def __init__(
        self,
        kv_channels: int,
        attention_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.norm_factor = math.sqrt(kv_channels)
        self.scale_mask_softmax = TorchScaledMaskedSoftmax()
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention for inputs laid out as [seq, batch, heads, head_dim].

        Returns:
            The context tensor reshaped to [sq, b, heads * head_dim].
        """
        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(
            output_size[2], output_size[0] * output_size[1], -1
        )
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # Preallocated result tensor: [b * np, sq, sk].
        # Allocate it on the query's own device rather than the current CUDA
        # device, so the module also works on CPU and on non-current GPUs.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 means the uninitialized buffer contents are ignored.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
        attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(
            value_layer.size(0), output_size[0] * output_size[1], -1
        )

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(
            output_size[0] * output_size[1], output_size[2], -1
        )

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        context_layer = context_layer.view(seqlen, batch_size, -1)

        return context_layer

216

217
218
219
220
221
222
223
224
225
226
227
228
# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
    """Reference RMSNorm: scale `x` by the reciprocal root-mean-square of its last dim.

    Args:
        in_features: size of the normalized (last) dimension.
        eps: value added to the mean square before the reciprocal square root.
    """

    def __init__(self, in_features, eps=1e-5):
        super().__init__()

        self.eps = eps
        self.in_features = in_features

        # Assigning an nn.Parameter attribute already registers it with the
        # module, so no explicit register_parameter call is needed.
        self.weight = nn.Parameter(torch.ones(in_features))

    def forward(self, x):
        """Return `weight * x / sqrt(mean(x**2) + eps)` cast back to `x.dtype`."""
        # The sum of squares is accumulated in float32 whatever the input dtype.
        norm_x2 = torch.sum(x.float()**2, dim=-1, keepdim=True)
        d_x = self.in_features

        rms_x2 = norm_x2 / d_x + self.eps
        r_rms_x = rms_x2 ** (-1. / 2)
        x_normed = x * r_rms_x

        return (self.weight.float() * x_normed).to(x.dtype)
237

238

239
class TorchLayerNormLinear(nn.Module):
    """Reference normalization followed by a linear projection.

    Args:
        in_features: input width, also the normalized dimension.
        out_features: output width of the linear layer.
        eps: epsilon passed to the normalization layer.
        bias: whether the linear layer has a bias term.
        normalization: "LayerNorm" or "RMSNorm".

    Raises:
        RuntimeError: if `normalization` is neither of the supported names.
    """

    def __init__(self, in_features: int, out_features: int,
                 eps: float, bias: bool = True,
                 normalization: str = "LayerNorm"):
        super().__init__()
        if normalization == "LayerNorm":
            self.layernorm = nn.LayerNorm(in_features, eps=eps)
        elif normalization == "RMSNorm":
            self.layernorm = TorchRMSNorm(in_features, eps=eps)
        else:
            raise RuntimeError("Unsupported normalization")

        # Forward `bias` so that bias=False really builds a bias-free projection.
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize `x`, then apply the linear layer."""
        return self.linear(self.layernorm(x))


class TorchMHA(nn.Module):
    """Self-attention reference that wraps `nn.MultiheadAttention`.

    The wrapped module is created with dropout 0.1, bias enabled and
    `batch_first=False`, and is exposed as `self.mhsa`.
    """

    def __init__(self, hidden_size: int, num_attention_heads: int):
        super().__init__()
        mha_options = dict(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            dropout=0.1,
            bias=True,
            batch_first=False,
        )
        self.mhsa = nn.MultiheadAttention(**mha_options)

    def forward(self, x, attention_mask=None):
        """Attend `x` to itself and return only the attention output tensor."""
        result = self.mhsa(x, x, x, attn_mask=attention_mask, need_weights=False)
        # A tuple result carries the output first; drop everything after it.
        return result[0] if isinstance(result, tuple) else result

274

275
276
277
278
279
280
# Maps an activation name to the torch module that implements it.
# For the *glu names this is the activation that TorchGLU applies to the
# first half of its input before multiplying by the second half.
_supported_act = {'geglu'  : nn.GELU(approximate="tanh"),
                  'gelu'  : nn.GELU(approximate="tanh"),
                  'reglu'  : nn.ReLU(),
                  'relu'  : nn.ReLU(),
                  'swiglu' : nn.SiLU()}

281

282
283
284
285
286
287
288
289
290
291
292
class TorchGLU(nn.Module):
    """Gated linear unit: activate the first half of the last dim, multiply by the second."""

    def __init__(self, activation: str):
        super().__init__()
        self.act = _supported_act[activation]

    def forward(self, x):
        """Split `x` at the midpoint of its last dimension and return act(first) * second."""
        midpoint = x.size(-1) // 2
        gate, value = x[..., :midpoint], x[..., midpoint:]
        return self.act(gate) * value
293

294

295
class TorchLayerNormMLP(nn.Module):
    """Reference normalization + two-layer MLP.

    Submodules are `ln` (normalization), `fc1`, the activation stored as
    `gelu` (whatever activation was requested), and `fc2`. For activation
    names containing 'glu', `fc1` produces twice `ffn_hidden_size` features
    and the activation is a `TorchGLU`.

    Raises:
        RuntimeError: if `normalization` is neither "LayerNorm" nor "RMSNorm".
    """

    def __init__(self, hidden_size: int, ffn_hidden_size: int,
                 eps: float = 1e-5, activation = 'gelu',
                 normalization: str = "LayerNorm"):
        super().__init__()
        if normalization not in ("LayerNorm", "RMSNorm"):
            raise RuntimeError("Unsupported normalization")
        norm_cls = nn.LayerNorm if normalization == "LayerNorm" else TorchRMSNorm
        self.ln = norm_cls(hidden_size, eps=eps)

        is_gated = 'glu' in activation
        self.gelu = TorchGLU(activation) if is_gated else _supported_act[activation]
        # A gated activation consumes two halves, so fc1 must be twice as wide.
        fc1_output_features = 2 * ffn_hidden_size if is_gated else ffn_hidden_size

        self.fc1 = nn.Linear(hidden_size, fc1_output_features)
        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)

    def forward(self, x):
        """Apply normalization, fc1, the activation and fc2 in that order."""
        normed = self.ln(x)
        activated = self.gelu(self.fc1(normed))
        return self.fc2(activated)
318
319
320


class TorchGPT(nn.Module):
    """Reference transformer block: LayerNorm + self-attention, and a LayerNorm-MLP.

    With `parallel_attention_mlp` set, attention and the MLP both read the
    block input and their sum is added back through one dropout. Otherwise the
    MLP runs on the result of the attention residual. All dropouts use p=0.1
    and are active only in training mode.
    """

    def __init__(self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)
        self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
        self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
        self.parallel_attention_mlp = parallel_attention_mlp

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the block on `x`, passing `attn_mask` to the attention module."""
        dropout = nn.functional.dropout
        attn_out = self.causal_attn(self.ln(x), attn_mask)

        if not self.parallel_attention_mlp:
            # Sequential form: attention residual first, then the MLP residual.
            x = x + dropout(attn_out, p=0.1, training=self.training)
            mlp_out = self.ln_mlp(x)
            return x + dropout(mlp_out, p=0.1, training=self.training)

        # Parallel form: both branches see the original input.
        mlp_out = self.ln_mlp(x)
        return x + dropout(attn_out + mlp_out, p=0.1, training=self.training)


345

346
def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
    """Run one forward/backward pass of a TransformerLayer with a causal mask.

    RNG and FP8 global state are reset first, so two calls that differ only in
    `recompute` start from the same conditions. `recompute` is passed to the
    layer as `checkpoint_core_attention`.

    Returns:
        A list: the layer output, the gradient of the input, then the gradient
        of every parameter that requires grad.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    init_std = 0.023
    init_method = init_method_normal(init_std)
    output_layer_init_method = scaled_init_method_normal(init_std, config.num_layers)

    rng_tracker = CudaRNGStatesTracker()
    rng_tracker.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker():
        """Get cuda rng tracker."""
        return rng_tracker

    layer_kwargs = dict(
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        params_dtype=dtype,
        fuse_qkv_params=True,
    )
    with fp8_model_init(enabled=fp8 and fp8_model_params):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            **layer_kwargs,
        ).cuda()

    # The .cuda() copy is not a leaf, hence retain_grad() to keep its gradient.
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_hidden_states.retain_grad()
    causal_mask = get_causal_attn_mask(config.seq_len)

    with fp8_autocast(enabled=fp8):
        te_out = block(
            te_inp_hidden_states,
            attention_mask=causal_mask,
            checkpoint_core_attention=recompute,
        )
    te_out.sum().backward()
    torch.cuda.synchronize()

    param_grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [te_out, te_inp_hidden_states.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
    """Core-attention checkpointing must give bitwise identical outputs and gradients."""
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    # Baseline first, then the recompute run.
    baseline, recomputed = (
        _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=flag)
        for flag in (False, True)
    )
    assert_all_equal(baseline, recomputed)


421
def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
    """Run one forward/backward pass of a TransformerLayer, optionally via `te_checkpoint`.

    RNG and FP8 global state are reset first. With `recompute` set, the layer
    is invoked through `te_checkpoint`; otherwise it is called directly with
    the same arguments.

    Returns:
        A list: the layer output, the gradient of the input, then the gradient
        of every parameter that requires grad.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    init_std = 0.023
    init_method = init_method_normal(init_std)
    output_layer_init_method = scaled_init_method_normal(init_std, config.num_layers)

    rng_tracker = CudaRNGStatesTracker()
    rng_tracker.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker():
        """Get cuda rng tracker."""
        return rng_tracker

    layer_kwargs = dict(
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        params_dtype=dtype,
        fuse_qkv_params=True,
    )
    with fp8_model_init(enabled=fp8 and fp8_model_params):
        block = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            **layer_kwargs,
        ).cuda()

    # The .cuda() copy is not a leaf, hence retain_grad() to keep its gradient.
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_hidden_states.retain_grad()

    forward_kwargs = dict(
        attention_mask=get_causal_attn_mask(config.seq_len),
        checkpoint_core_attention=False,
    )

    with fp8_autocast(enabled=fp8):
        if not recompute:
            te_out = block(te_inp_hidden_states, **forward_kwargs)
        else:
            te_out = te_checkpoint(
                block,
                False,  # distribute_saved_activations
                get_dummy_cuda_rng_tracker,
                None,  # tp_group
                te_inp_hidden_states,
                **forward_kwargs,
            )
    te_out.sum().backward()
    torch.cuda.synchronize()

    param_grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [te_out, te_inp_hidden_states.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
    """Running the layer through `te_checkpoint` must give bitwise identical results."""
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    # Baseline first, then the recompute run.
    baseline, recomputed = (
        _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=flag)
        for flag in (False, True)
    )
    assert_all_equal(baseline, recomputed)


def _test_e2e_checkpointing_get_model(config, dtype):
    """Build the CUDA TransformerLayer used by the checkpointing test.

    The layer is sized from `config`, uses dropout 0.1 for hidden states and
    attention, and creates its parameters in `dtype`.
    """
    init_std = 0.023
    layer = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
    )
    return layer.cuda()


def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path="checkpoint.pt"):
    """Accumulate gradients over `steps` passes, optionally reloading the model halfway.

    The first `steps // 2` forward/backward passes run on a fresh model. With
    `checkpoint` set, the state dict is saved to `path`, the model is dropped
    and rebuilt, the state dict and accumulated gradients are restored, and the
    RNG is put back to where it was at save time. The remaining `steps // 2`
    passes then run on the reloaded model.

    The RNG snapshot is held in local variables, so the module-level snapshot
    used by `reset_rng_states()` is left untouched. Any file at `path` is
    removed on exit, including when the run raises.

    Returns:
        A list: the last output, the gradient of the input, then the gradient
        of every parameter that requires grad.
    """
    reset_rng_states()

    # The .cuda() copy is not a leaf, hence retain_grad() to keep its gradient.
    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    te_inp_hidden_states.retain_grad()

    block = _test_e2e_checkpointing_get_model(config, dtype)

    for _ in range(steps // 2):
        te_out = block(
            te_inp_hidden_states,
            None,
        )
        loss = te_out.sum()
        loss.backward()

    try:
        if checkpoint:
            # This process is necessary so that we can start afresh with
            # a new model while erasing all internal state to ensure that
            # loading from a checkpoint gives bitwise identical results.
            # Since gradients are being accumulated, it is important to
            # restore them post loading the checkpoint.
            torch.save(block.state_dict(), path)

            param_grads = [p.grad.clone() for p in block.parameters() if p.requires_grad]

            # Snapshot the RNG before the model is rebuilt.
            cpu_rng_state = torch.get_rng_state()
            cuda_rng_state = torch.cuda.get_rng_state()

            del block
            block = _test_e2e_checkpointing_get_model(config, dtype)
            block.load_state_dict(torch.load(path))

            # Rewind the RNG to the state captured at save time.
            torch.set_rng_state(cpu_rng_state)
            _set_cuda_rng_state(cuda_rng_state)

            trainable = [p for p in block.parameters() if p.requires_grad]
            assert len(trainable) == len(param_grads), "Oops!"
            for p, saved_grad in zip(trainable, param_grads):
                p.grad = saved_grad

        for _ in range(steps // 2):
            te_out = block(
                te_inp_hidden_states,
                None,
            )
            loss = te_out.sum()
            loss.backward()

        torch.cuda.synchronize()
    finally:
        # Remove the checkpoint file even when the run fails.
        if os.path.exists(path):
            os.remove(path)

    outputs = [te_out, te_inp_hidden_states.grad]
    outputs.extend(p.grad for p in block.parameters() if p.requires_grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_checkpointing(dtype, bs, model):
    """A save/reload halfway through must not change outputs or gradients."""
    config = model_configs[model]
    # Uninterrupted run first, then the checkpointed one.
    uninterrupted, reloaded = (
        _test_e2e_checkpointing(bs, dtype, config, checkpoint=flag)
        for flag in (False, True)
    )
    assert_all_equal(uninterrupted, reloaded)
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630


def _test_e2e_gpt_accuracy(block, bs, dtype, config):
    """Run `block` once on a seeded random input with a causal mask.

    Returns:
        A list: the output, the gradient of the input, then the gradient of
        every parameter that requires grad.
    """
    reset_rng_states()

    # The .cuda() copy is not a leaf, hence retain_grad() to keep its gradient.
    hidden = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    hidden.retain_grad()
    causal_mask = get_causal_attn_mask(config.seq_len)

    out = block(hidden, causal_mask)
    out.sum().backward()

    torch.cuda.synchronize()
    param_grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [out, hidden.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
    """Compare a TransformerLayer in eval mode against the TorchGPT reference.

    The reference receives copies of the TE layer's parameters. Only the
    forward outputs are compared.
    """
    config = model_configs[model]

    te_gpt = TransformerLayer(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        num_attention_heads=config.num_attention_heads,
        layernorm_epsilon=config.eps,
        attention_dropout=0.1,
        hidden_dropout=0.1,
        fuse_qkv_params=True,
        qkv_weight_interleaved=False,
        parallel_attention_mlp=parallel_attention_mlp,
    )
    te_gpt = te_gpt.to(dtype=dtype).cuda().eval()

    torch_gpt = TorchGPT(
        config.hidden_size,
        config.eps,
        config.num_attention_heads,
        parallel_attention_mlp=parallel_attention_mlp,
    )
    torch_gpt = torch_gpt.to(dtype=dtype).cuda().eval()

    # Share params
    te_qkv = te_gpt.self_attention.layernorm_qkv
    te_proj = te_gpt.self_attention.proj
    te_mlp = te_gpt.layernorm_mlp
    ref_attn = torch_gpt.causal_attn.mhsa
    ref_mlp = torch_gpt.ln_mlp
    with torch.no_grad():
        # Input layernorm.
        torch_gpt.ln.weight = Parameter(te_qkv.layer_norm_weight.clone())
        torch_gpt.ln.bias = Parameter(te_qkv.layer_norm_bias.clone())
        # Attention projections.
        ref_attn.in_proj_weight = Parameter(te_qkv.weight.clone())
        ref_attn.in_proj_bias = Parameter(te_qkv.bias.clone())
        ref_attn.out_proj.weight = Parameter(te_proj.weight.clone())
        ref_attn.out_proj.bias = Parameter(te_proj.bias.clone())
        # LayerNorm-MLP.
        ref_mlp.ln.weight = Parameter(te_mlp.layer_norm_weight.clone())
        ref_mlp.ln.bias = Parameter(te_mlp.layer_norm_bias.clone())
        ref_mlp.fc1.weight = Parameter(te_mlp.fc1_weight.clone())
        ref_mlp.fc1.bias = Parameter(te_mlp.fc1_bias.clone())
        ref_mlp.fc2.weight = Parameter(te_mlp.fc2_weight.clone())
        ref_mlp.fc2.bias = Parameter(te_mlp.fc2_bias.clone())

    te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config)
    torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config)

    # Check output.
    atol = 5e-3 if dtype == torch.float32 else 5e-2
    assert_allclose(te_outputs[0], torch_outputs[0], atol)


699
def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
    """Run an attention module once on a seeded random input.

    A causal mask is built when `mask_type` is "causal"; otherwise the mask is
    None. `attn_mask_type` is passed to the module only when `te` is true.

    Returns:
        A list: the output, the gradient of the input, then the gradient of
        every parameter that requires grad.
    """
    reset_rng_states()

    # The .cuda() copy is not a leaf, hence retain_grad() to keep its gradient.
    hidden = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    hidden.retain_grad()

    attn_mask = None
    if mask_type == "causal":
        attn_mask = get_causal_attn_mask(config.seq_len)

    forward_kwargs = {"attention_mask": attn_mask}
    if te:
        forward_kwargs["attn_mask_type"] = mask_type

    out = block(hidden, **forward_kwargs)
    out.sum().backward()

    torch.cuda.synchronize()
    param_grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [out, hidden.grad, *param_grads]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
    """Compare MultiheadAttention in eval mode against the TorchMHA reference."""
    config = model_configs[model]

    te_mha = MultiheadAttention(
        config.hidden_size,
        config.num_attention_heads,
        fuse_qkv_params=True,
        qkv_weight_interleaved=False,
        input_layernorm=False,
    )
    te_mha = te_mha.to(dtype=dtype).cuda().eval()

    torch_mha = TorchMHA(config.hidden_size, config.num_attention_heads)
    torch_mha = torch_mha.to(dtype=dtype).cuda().eval()

    # Share params
    ref = torch_mha.mhsa
    with torch.no_grad():
        ref.in_proj_weight = Parameter(te_mha.qkv.weight.clone())
        ref.in_proj_bias = Parameter(te_mha.qkv.bias.clone())
        ref.out_proj.weight = Parameter(te_mha.proj.weight.clone())
        ref.out_proj.bias = Parameter(te_mha.proj.bias.clone())

    te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type, te=True)
    torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type, te=False)

    # Check output.
    atol = 5e-3 if dtype == torch.float32 else 5e-2
    assert_allclose(te_outputs[0], torch_outputs[0], atol)


772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
def _test_granular_accuracy(block, bs, dtype, config):
    """Run a single-input module once on a seeded random input.

    Returns:
        A list: the output, the gradient of the input, then the gradient of
        every parameter that requires grad.
    """
    reset_rng_states()

    # The .cuda() copy is not a leaf, hence retain_grad() to keep its gradient.
    hidden = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    hidden.retain_grad()

    out = block(hidden)
    out.sum().backward()

    torch.cuda.synchronize()
    param_grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [out, hidden.grad, *param_grads]


792
793
794
795
796
797
798
799
800
801
802
803
def _test_dpa_accuracy(block, bs, dtype, config):
    """Run a dot-product-attention module once on seeded random q/k/v with a causal mask.

    Returns:
        A list: the output followed by the gradients of query, key and value.
    """
    reset_rng_states()

    mask = get_causal_attn_mask(config.seq_len)

    qkv_shape = (config.seq_len, bs, config.num_attention_heads, config.embed)
    qkv = []
    for _ in range(3):
        # The .cuda() copy is not a leaf, hence retain_grad() below.
        qkv.append(torch.randn(*qkv_shape, dtype=dtype, requires_grad=True).cuda())
    query, key, value = qkv

    for tensor in (query, key, value):
        tensor.retain_grad()

    out = block(query, key, value, attention_mask=mask)
    out.sum().backward()

    torch.cuda.synchronize()

    return [out, query.grad, key.grad, value.grad]


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_dpa_accuracy(dtype, bs, model):
    """Compare DotProductAttention against the TorchDotProductAttention reference."""
    config = model_configs[model]

    # Dropout is disabled on both sides so the comparison does not depend on
    # how each implementation consumes the RNG.
    te_dpa = DotProductAttention(
        config.num_attention_heads,
        config.embed,
        attention_dropout=0.0,
    )
    te_dpa = te_dpa.to(dtype=dtype).cuda()

    torch_dpa = TorchDotProductAttention(
        config.embed,
        0.0,  # dropout
    )
    torch_dpa = torch_dpa.to(dtype=dtype).cuda()

    te_outputs = _test_dpa_accuracy(te_dpa, bs, dtype, config)
    torch_outputs = _test_dpa_accuracy(torch_dpa, bs, dtype, config)

    # Check output.
    atol = 5e-3 if dtype == torch.float32 else 5e-2
    assert_allclose(te_outputs[0], torch_outputs[0], atol)


848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_linear_accuracy(dtype, bs, model):
    """Compare the TE Linear layer against `torch.nn.Linear` with shared parameters."""
    config = model_configs[model]
    in_features, out_features = config.hidden_size, 4 * config.hidden_size

    te_linear = Linear(in_features, out_features, bias=True)
    te_linear = te_linear.to(dtype=dtype).cuda().eval()

    torch_linear = torch.nn.Linear(in_features, out_features, bias=True)
    torch_linear = torch_linear.to(dtype=dtype).cuda().eval()

    # Share params
    with torch.no_grad():
        torch_linear.weight = Parameter(te_linear.weight.clone())
        torch_linear.bias = Parameter(te_linear.bias.clone())

    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)

    # Check output.
    atol = 5e-3 if dtype == torch.float32 else 5e-2
    assert_allclose(te_outputs[0], torch_outputs[0], atol)

890

891
892
893
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
def test_rmsnorm_accuracy(dtype, bs, model, eps):
    """Check TE ``RMSNorm`` against ``TorchRMSNorm`` using a shared weight.

    Both modules are built over ``config.hidden_size`` with the parametrized
    ``eps``, cast to ``dtype``, moved to the GPU and put in eval mode. The
    first element returned by ``_test_granular_accuracy`` for each module is
    compared within a dtype-dependent tolerance.
    """
    config = model_configs[model]

    te_rmsnorm = (
        RMSNorm(
            config.hidden_size,
            eps=eps,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    torch_rmsnorm = (
        TorchRMSNorm(
            config.hidden_size,
            eps=eps,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params: the reference gets its own copy of the TE weight.
    with torch.no_grad():
        torch_rmsnorm.weight = Parameter(te_rmsnorm.weight.clone())

    te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)

    # Check output.
    if dtype == torch.float32:
        assert_allclose(te_outputs[0], torch_outputs[0], 1e-7)
    else:
        assert_allclose(te_outputs[0], torch_outputs[0], 2e-2)
930

931

932
933
934
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("normalization", all_normalizations)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization):
    """Check TE ``LayerNormLinear`` against ``TorchLayerNormLinear``.

    Both modules map ``hidden_size`` to ``4 * hidden_size`` with a bias and
    the parametrized ``normalization``. Parameters are copied from the TE
    module into the reference, and the first element returned by
    ``_test_granular_accuracy`` for each module is compared within a
    dtype-dependent tolerance.
    """
    config = model_configs[model]

    te_ln_linear = (
        LayerNormLinear(
            config.hidden_size,
            4 * config.hidden_size,
            config.eps,
            bias=True,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    torch_ln_linear = (
        TorchLayerNormLinear(
            config.hidden_size,
            4 * config.hidden_size,
            config.eps,
            bias=True,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_ln_linear.layernorm.weight = Parameter(te_ln_linear.layer_norm_weight.clone())
        # The norm bias is only copied for non-RMSNorm runs
        # (presumably the RMSNorm variant has no bias -- TODO confirm).
        if normalization != "RMSNorm":
            torch_ln_linear.layernorm.bias = Parameter(te_ln_linear.layer_norm_bias.clone())
        torch_ln_linear.linear.weight = Parameter(te_ln_linear.weight.clone())
        torch_ln_linear.linear.bias = Parameter(te_ln_linear.bias.clone())

    te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)

    # Check output.
    if dtype == torch.float32:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
    else:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
981
982


983
984
985
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
    """Check TE ``LayerNormMLP`` against ``TorchLayerNormMLP``.

    Both modules are built with ``hidden_size`` and ``4 * hidden_size`` plus
    the parametrized ``activation`` and ``normalization``. Parameters are
    copied from the TE module into the reference, and the first element
    returned by ``_test_granular_accuracy`` for each module is compared
    within a dtype-dependent tolerance.
    """
    config = model_configs[model]

    te_ln_mlp = (
        LayerNormMLP(
            config.hidden_size,
            4 * config.hidden_size,
            activation=activation,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    torch_ln_mlp = (
        TorchLayerNormMLP(
            config.hidden_size,
            4 * config.hidden_size,
            activation=activation,
            normalization=normalization,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.layer_norm_weight.clone())
        # The norm bias is only copied for non-RMSNorm runs
        # (presumably the RMSNorm variant has no bias -- TODO confirm).
        if normalization != "RMSNorm":
            torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone())
        torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.fc1_weight.clone())
        torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.fc1_bias.clone())
        torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.fc2_weight.clone())
        torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.fc2_bias.clone())

    te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)

    # Check output.
    if dtype == torch.float32:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
    else:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)


1035
1036
1037
1038
1039
1040
1041
def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
    """Run a short SGD training sequence on ``block``, eagerly or via a CUDA graph.

    Args:
        block: Module to train; called as ``block(static_input)`` and its
            parameters are handed to the SGD optimizer.
        bs: Batch size, used as the second dimension of the input and target.
        dtype: Dtype of the input and target tensors.
        config: Object providing ``seq_len`` and ``hidden_size``.
        graph: When truthy, one training step is captured in a
            ``torch.cuda.CUDAGraph`` and replayed on the new data; otherwise
            that step is executed directly.

    Returns:
        ``(output, grads)``: a clone of the last step's output, and a list
        holding ``static_input.grad`` followed by the ``.grad`` of every
        parameter of ``block`` that requires grad.
    """
    reset_rng_states()

    # Initialize loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

    # Placeholders used for graph capture.
    static_input = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True)
    static_target = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype)

    # Data for the final step; copied into the static placeholders below.
    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)

    # Basic training loop.
    def train_step():
        # set_to_none=False zeroes the parameter gradients in place instead
        # of dropping the tensors. static_input is not an optimizer
        # parameter, so its .grad is not reset here.
        optimizer.zero_grad(set_to_none=False)
        out = block(static_input)
        loss = loss_fn(out, static_target)
        loss.backward()
        optimizer.step()
        return out

    # Warmup steps in a separate stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            train_step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture graph.
    g = None
    static_output = None
    if graph:
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_output = train_step()

    # Run with new data.
    with torch.no_grad():
        static_input.copy_(real_input)
        static_target.copy_(real_target)
    if graph:
        # Replaying fills static_output (and the gradients) for the new data.
        g.replay()
    else:
        static_output = train_step()

    grads = [static_input.grad]
    for p in block.parameters():
        if p.requires_grad:
            grads.append(p.grad)

    with torch.no_grad():
        output = static_output.clone()
    return output, grads


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_cuda_graph(dtype, bs, model):
    """Check that CUDA-graphed training of a ``TransformerLayer`` matches eager.

    A layer is built and deep-copied; the original is run through
    ``_test_gpt_e2e_cuda_graph`` without graph capture and the copy with it.
    Outputs, parameters and gradients of the two runs must agree within 1e-3.
    """
    config = model_configs[model]

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
        )
        .to(dtype=dtype)
        .cuda()
    )
    # Independent copy, so both runs start from identical parameters.
    graphed_block = copy.deepcopy(block)

    out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
    graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
    params = list(block.parameters())
    graphed_params = list(graphed_block.parameters())

    # Check that results match
    assert_allclose(out, graphed_out, 1e-3)
    assert_allclose(params, graphed_params, 1e-3)
    assert_allclose(grads, graphed_grads, 1e-3)
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199


def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):
    """Run one FP8 forward/backward pass of a ``TransformerLayer``.

    Args:
        bs: Batch size, used as the second dimension of the input.
        dtype: Dtype of the input; also passed as ``params_dtype``.
        config: Object providing the model dimensions and ``eps``.
        fp8_model_params: Forwarded as ``enabled`` to ``fp8_model_init``
            while the layer is constructed.

    Returns:
        A list with the layer output, the gradient of the input, and the
        ``.grad`` of every layer parameter that requires grad.
    """
    reset_rng_states()
    FP8GlobalStateManager.reset()

    sigma = 0.023
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    rng_tracker = CudaRNGStatesTracker()
    rng_tracker.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker():
        """Get cuda rng tracker."""
        return rng_tracker

    # The layer is created (and moved to the GPU) inside the fp8_model_init context.
    with fp8_model_init(enabled=fp8_model_params):
        layer = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=config.eps,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            kv_channels=config.embed,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            params_dtype=dtype,
            fuse_qkv_params=True,
        )
        block = layer.cuda()

    # The input is sampled on the host and then copied to the GPU; the copy is
    # a non-leaf tensor, so retain_grad() is needed to read its gradient later.
    host_input = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    )
    hidden_states = host_input.cuda()
    hidden_states.retain_grad()
    attn_mask = get_causal_attn_mask(config.seq_len)

    with fp8_autocast(enabled=True):
        te_out = block(hidden_states, attention_mask=attn_mask)
    te_out.sum().backward()
    torch.cuda.synchronize()

    param_grads = [p.grad for p in block.parameters() if p.requires_grad]
    return [te_out, hidden_states.grad] + param_grads


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_fp8_parameters(dtype, bs, model):
    """Check that enabling ``fp8_model_init`` does not change the results.

    Runs ``_test_gpt_fp8_parameters`` once with ``fp8_model_params`` off and
    once with it on, and requires the two result lists to be equal. Skipped
    when FP8 is not available.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]

    # Baseline first, then the FP8-parameter run, as in the original order.
    baseline = _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params=False)
    with_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params=True)
    assert_all_equal(baseline, with_fp8_params)