#!/usr/bin/python3

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import argparse
import datetime
import os
import sys
from functools import wraps

import transformer_engine.pytorch as te
import torch
from torch import nn
import torch.distributed as dist
import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
    MXFP8BlockScaling,
    DelayedScaling,
    Float8CurrentScaling,
    Float8BlockScaling,
    Format,
    Recipe,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

from run_layer_with_overlap import _compare_tensors

# Problem sizes. main() enlarges SEQ_LEN, BATCH_SIZE and HIDDEN_SIZE for the
# "fp8", "mxfp8" and "fp8_block_scaling" quantization schemes.
SEQ_LEN, BATCH_SIZE = 16, 16
HIDDEN_SIZE = 64
NR_HEADS = 4

# Distributed state; populated by main() once the process group exists.
WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD = None

# Loss applied to both the single-GPU and the distributed outputs.
LOSS_FN = nn.MSELoss()

# Quantization scheme selected on the command line (None = no quantization).
QUANTIZATION = None


# Disable TF32 so fp32 matmuls/convolutions run in full fp32 precision; the
# unquantized fp32 comparisons use tight tolerances.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


# Quantization recipe setup
def quantization_recipe() -> Recipe:
    """Return the Transformer Engine recipe for the global ``QUANTIZATION`` value.

    Mapping:
        "fp8"               -> DelayedScaling (HYBRID format, amax history of 32,
                               "max" amax reduction)
        "mxfp8"             -> MXFP8BlockScaling
        "fp8_cs"            -> Float8CurrentScaling
        "fp8_block_scaling" -> Float8BlockScaling
        anything else (including None) -> te.fp8.get_default_fp8_recipe()
    """
    if QUANTIZATION == "fp8":
        return DelayedScaling(
            fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
        )
    if QUANTIZATION == "mxfp8":
        return MXFP8BlockScaling()
    if QUANTIZATION == "fp8_cs":
        return Float8CurrentScaling()
    if QUANTIZATION == "fp8_block_scaling":
        return Float8BlockScaling()
    return te.fp8.get_default_fp8_recipe()


def main(argv=None, namespace=None):
    """Initialise NCCL, parse the command line and run every numerics test.

    The process layout is read from the RANK, WORLD_SIZE, LOCAL_RANK and
    LOCAL_WORLD_SIZE environment variables; only single-node runs are supported.

    Args:
        argv: Argument list for ``argparse`` (``None`` means ``sys.argv[1:]``).
        namespace: Optional ``argparse.Namespace`` to populate.

    Returns:
        0 once all tests have run (a failing test raises instead).
    """
    # Declared up front: every one of these module-level settings may be
    # reassigned below.
    global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, QUANTIZATION
    global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE

    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))

    assert WORLD_SIZE == LOCAL_SIZE  # this test supports only 1 node
    assert LOCAL_SIZE <= torch.cuda.device_count()
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
        "timeout": datetime.timedelta(seconds=30),
        "init_method": "env://",
        "device_id": torch.device(f"cuda:{LOCAL_RANK}"),
    }
    assert dist.is_nccl_available()
    torch.cuda.set_device(LOCAL_RANK)
    dist.init_process_group(**dist_init_kwargs)

    NCCL_WORLD = dist.new_group(backend="nccl")

    WORLD_SIZE = dist.get_world_size()

    parser = argparse.ArgumentParser()
    # Accepted for command-line compatibility; not read anywhere in this file.
    parser.add_argument("-l", "--layer-type", type=str)
    parser.add_argument("--quantization", type=str, default=None)
    args = parser.parse_args(argv, namespace)

    # Quantization scheme; these schemes run on a larger problem size.
    QUANTIZATION = args.quantization
    if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"):
        SEQ_LEN = 32
        BATCH_SIZE = 32
        HIDDEN_SIZE = 128

    tests = [
        test_quantizer,
        test_linear,
        test_layernorm,
        test_layernorm_linear,
        test_layernorm_mlp,
        test_transformer_layer,
    ]

    for test in tests:
        test()
    dist.destroy_process_group()
    return 0


def run_distributed_test(test_name=None):
    """Build a decorator that wraps one distributed test configuration.

    Before the wrapped function runs, the CUDA device is set to ``WORLD_RANK``
    and both the CPU and CUDA RNGs are reseeded with a fixed value. Afterwards,
    all ranks meet at a barrier and rank 0 logs that the test passed.

    Args:
        test_name: Label used in the log lines; defaults to the wrapped
            function's ``__name__``.
    """
    seed = 12345

    def decorator(func):
        label = func.__name__ if test_name is None else test_name

        @wraps(func)
        def wrapper(*args, **kwargs):
            dist_print(f"Starting test {label} with args {args} and {kwargs}")
            torch.cuda.set_device(WORLD_RANK)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            func(*args, **kwargs)

            # Keep ranks in lockstep before reporting success.
            dist.barrier()
            dist_print(f"Passed test {label}")

        return wrapper

    return decorator


def _gather(tensor, dim=0):
    """All-gather ``tensor`` across NCCL_WORLD and concatenate the pieces along ``dim``.

    ``torch.distributed.nn.functional.all_gather`` multiplies the gradients
    flowing back into ``tensor`` by WORLD_SIZE. An identity op whose backward
    divides by WORLD_SIZE is inserted in front of it to cancel that factor.
    """

    class _DivideGradByWorldSize(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            # Identity in the forward direction.
            return inp

        @staticmethod
        def backward(ctx, grad):
            return grad / WORLD_SIZE

    rescaled = _DivideGradByWorldSize.apply(tensor)
    pieces = torch.distributed.nn.functional.all_gather(rescaled, group=NCCL_WORLD)
    return torch.cat(pieces, dim=dim)


def _constant(tensor):
    """Initialiser that fills ``tensor`` in place with 0.5 and returns it."""
    fill_value = 0.5
    return nn.init.constant_(tensor, fill_value)


def dist_print(msg, src=None, end="\n", error=False):
    """Write ``msg`` prefixed with this rank's id, on a single rank only.

    Args:
        msg: Object to print (formatted with ``str``).
        src: Rank that should print; defaults to rank 0. Callers pass
            ``src=WORLD_RANK`` to print from every rank.
        end: String appended after ``msg``, like ``print``'s ``end``.
        error: Write to stderr instead of stdout.
    """
    if WORLD_RANK != (0 if src is None else src):
        return
    stream = sys.stderr if error else sys.stdout
    # Previously "{end}\n" was written, so ``end`` was never honoured and every
    # message was followed by an extra blank line.
    stream.write(f"[rank{WORLD_RANK}] {msg}{end}")
    # Flush so output is not lost in buffers if another rank aborts the job.
    stream.flush()


def _get_tolerances(dtype):
    """Return ``rtol``/``atol`` keyword arguments for output/gradient comparisons.

    When a quantization scheme is active, fixed loose tolerances are used
    regardless of ``dtype``. Otherwise the tolerance depends on the
    floating-point dtype.

    Raises:
        ValueError: If no quantization is active and ``dtype`` is not float16,
            bfloat16 or float32.
    """
    if QUANTIZATION == "fp8_cs":
        # Loose tolerances for fp8_cs: with row parallel + sequence parallel the
        # all-gather happens in the backward pass, so after amax reduction each
        # rank can end up with a different scale_inv when computing Y.
        return {"rtol": 0.4, "atol": 0.25}
    if QUANTIZATION is not None:
        return {"rtol": 0.125, "atol": 0.0625}

    if dtype == torch.float16:
        return {"rtol": 1e-3, "atol": 1e-5}
    if dtype == torch.bfloat16:
        return {"rtol": 1.6e-2, "atol": 1e-5}
    if dtype == torch.float32:
        return {"rtol": 1.3e-6, "atol": 1e-5}
    raise ValueError(f"Unsupported dtype ({dtype})")


def _check_outputs(output_single_node, output_distributed):
    """Assert that the distributed output matches the single-GPU reference.

    The comparison runs locally on every rank. A failure flag is MAX-reduced
    over NCCL_WORLD, so every rank asserts if any single rank mismatched.
    """
    output_failed, output_info = _compare_tensors(
        "outputs",
        output_distributed,
        output_single_node,
        **_get_tolerances(output_single_node.dtype),
    )
    if output_failed:
        dist_print(output_info, src=WORLD_RANK, error=True)
    numerics_failed = torch.tensor([int(output_failed)], dtype=torch.uint8, device="cuda")
    dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, NCCL_WORLD)
    assert not bool(numerics_failed.item())


def _match_param_sizes(dist_param, single_param):
    """Return the part of ``single_param`` that this rank's ``dist_param`` covers.

    For every dimension where the two shapes disagree, ``single_param`` is
    narrowed to the contiguous chunk ``[WORLD_RANK * n, (WORLD_RANK + 1) * n)``,
    where ``n = dist_param.shape[dim]``. Dimensions that already agree are
    kept whole.

    Args:
        dist_param: Tensor from the distributed model (per-rank shape).
        single_param: Corresponding tensor from the single-GPU model.

    Returns:
        A slice of ``single_param`` that is expected to have ``dist_param``'s shape.
    """
    indices = [slice(None)] * single_param.dim()
    for dim, local_size in enumerate(dist_param.shape):
        if local_size == single_param.shape[dim]:
            continue
        begin = WORLD_RANK * local_size
        indices[dim] = slice(begin, begin + local_size)
    return single_param[tuple(indices)]


def _check_gradients(model_distributed, model_single, main_grad_check=False):
    """Compare each distributed parameter gradient with the single-GPU reference.

    Parameters are paired positionally. Before comparison, the reference
    gradient is sliced to this rank's shard with ``_match_param_sizes``. Each
    parameter's failure flag is MAX-reduced over NCCL_WORLD, so all ranks
    assert together.

    Args:
        model_distributed: Tensor-parallel model under test.
        model_single: Single-GPU reference model.
        main_grad_check: Compare ``param.main_grad`` (populated when
            ``fuse_wgrad_accumulation`` is used) instead of ``param.grad``.
    """
    grad_attr = "main_grad" if main_grad_check else "grad"
    for i, ((name, param_d), param_s) in enumerate(
        zip(model_distributed.named_parameters(), model_single.parameters())
    ):
        grad_d = getattr(param_d, grad_attr)
        grad_s = _match_param_sizes(grad_d, getattr(param_s, grad_attr))
        grad_failed, grad_info = _compare_tensors(
            str(i), grad_d, grad_s, **_get_tolerances(grad_s.dtype)
        )

        if grad_failed:
            dist_print(i, src=WORLD_RANK)
            dist_print(name, src=WORLD_RANK)
            dist_print(grad_info, src=WORLD_RANK, error=True)
        numerics_failed = torch.tensor([int(grad_failed)], dtype=torch.uint8, device="cuda")
        dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, NCCL_WORLD)
        assert not bool(numerics_failed.item())


def _copy_params(model_distributed, model_single):
    """Initialise the distributed model's parameters from the reference model.

    Parameters are paired positionally. For every dimension where the
    distributed parameter is smaller, the reference parameter is narrowed to
    this rank's contiguous chunk ``[WORLD_RANK * n, (WORLD_RANK + 1) * n)``.
    All mismatching dimensions are narrowed together, matching
    ``_match_param_sizes``, before the values are copied in place.
    """
    with torch.no_grad():
        for dist_param, single_param in zip(
            model_distributed.parameters(), model_single.parameters()
        ):
            indices = [slice(None)] * single_param.dim()
            for dim, local_size in enumerate(dist_param.shape):
                if local_size != single_param.shape[dim]:
                    start = WORLD_RANK * local_size
                    indices[dim] = slice(start, start + local_size)
            dist_param.copy_(single_param[tuple(indices)])


def _apply_models(
    model_single_node, model_distributed, input_single_node, input_distributed, **kwargs
):
    """Run the forward pass of the reference and the distributed model.

    Both inputs are marked ``requires_grad`` so autograd tracks them. Each
    forward runs under ``te.fp8_autocast``, which is enabled only when a
    quantization scheme is set. The distributed forward additionally passes
    ``fp8_group=NCCL_WORLD``. Extra ``kwargs`` are forwarded to both models.

    Returns:
        Tuple ``(output_single_node, output_distributed)``.
    """
    _alloc_main_grad(model_single_node, model_distributed)  # for fuse_wgrad_accumulation=True
    input_single_node.requires_grad_()
    input_distributed.requires_grad_()
    with te.fp8_autocast(
        enabled=QUANTIZATION is not None,
        fp8_recipe=quantization_recipe(),
    ):
        output_single_node = model_single_node(input_single_node, **kwargs)
    with te.fp8_autocast(
        enabled=QUANTIZATION is not None,
        fp8_recipe=quantization_recipe(),
        fp8_group=NCCL_WORLD,
    ):
        output_distributed = model_distributed(input_distributed, **kwargs)
    return output_single_node, output_distributed


def _loss_backward(output_single_node, output_distributed):
    """Backpropagate LOSS_FN for both outputs against one shared random target."""
    target = torch.randn_like(output_single_node)
    for output in (output_single_node, output_distributed):
        LOSS_FN(output, target).backward()


def _loss_backward_dw(model_single_node, model_distributed):
    """Call ``backward_dw()`` (delayed weight-gradient pass) on both models, reference first."""
    for model in (model_single_node, model_distributed):
        model.backward_dw()


def _alloc_main_grad(model_single_node, model_distributed):
    """Attach a zero-filled float32 ``main_grad`` buffer to every parameter of both models."""
    params = [p for m in (model_single_node, model_distributed) for p in m.parameters()]
    for p in params:
        p.main_grad = torch.zeros_like(p, dtype=torch.float32)


###############################################
#                   Quantizer                 #
###############################################
def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size):
    """Build a single-GPU reference quantizer and its distributed counterpart.

    Returns:
        ``(quantizer, quantizer_dist)``. ``quantizer`` is the reference with
        amax reduction disabled. ``quantizer_dist`` reduces amax over
        ``tp_group``. ``tp_size`` is accepted but not used.

    Raises:
        ValueError: For any class other than ``Float8CurrentScalingQuantizer``.
    """
    if quantizer_class != Float8CurrentScalingQuantizer:
        raise ValueError(f"Unsupported quantizer class: {quantizer_class}")

    shared = {"fp8_dtype": fp8_dtype, "device": device}
    distributed = quantizer_class(
        **shared, with_amax_reduction=True, amax_reduction_group=tp_group
    )
    reference = quantizer_class(**shared, with_amax_reduction=False)
    return reference, distributed


def _shard_tensor(x, world_size, axis, device="cuda"):
    """Split ``x`` into ``world_size`` equal chunks along ``axis``.

    Each chunk is an independent copy on ``device``, detached from ``x``. It is
    a leaf tensor whose ``requires_grad`` matches ``x``. The device move happens
    before ``requires_grad_``; otherwise the returned tensor would be a non-leaf
    and its ``.grad`` would never be populated.

    Args:
        x: Tensor to split.
        world_size: Number of chunks to produce.
        axis: Dimension to split along.
        device: Target device of the chunks (defaults to ``"cuda"``).

    Returns:
        List of ``world_size`` tensors.

    Raises:
        ValueError: If ``x.size(axis)`` is not divisible by ``world_size``
            (``torch.split`` would otherwise return extra, uneven chunks).
    """
    dim_size = x.size(axis)
    if dim_size % world_size != 0:
        raise ValueError(
            f"Cannot split dimension {axis} of size {dim_size} into {world_size} equal shards"
        )
    return [
        chunk.detach().clone().to(device).requires_grad_(x.requires_grad)
        for chunk in torch.split(x, dim_size // world_size, axis)
    ]


@run_distributed_test()
def _test_quantizer(input_dtype, fp8_dtype):
    """Check that the amax-reducing quantizer reproduces the reference scale_inv.

    A ``(WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE)`` input gets one large outlier in
    its last element. Every rank quantizes its row shard with the distributed
    quantizer. Rank 0 also quantizes the full tensor with the non-reducing
    reference quantizer and requires both ``_scale_inv`` values to match exactly.

    Args:
        input_dtype (torch.dtype): The data type of the input.
        fp8_dtype (tex.DType): The data type of the fp8.
    """
    is_reference_rank = WORLD_RANK == 0
    rows, cols = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE

    x_full = torch.randn((rows, cols), device="cpu").to(input_dtype)
    # The outlier sits in the last row, outside rank 0's shard (for WORLD_SIZE > 1),
    # so rank 0 only sees it through amax reduction.
    x_full[rows - 1, cols - 1] = 1e4
    if is_reference_rank:
        x_ref = x_full.clone().detach().requires_grad_(True).to("cuda")
    x_local = _shard_tensor(x_full, WORLD_SIZE, 0)[WORLD_RANK]

    quantizer_ref, quantizer_dist = _construct_quantizer(
        Float8CurrentScalingQuantizer, fp8_dtype, x_local.device, NCCL_WORLD, WORLD_SIZE
    )

    if is_reference_rank:
        x_fp8_ref = quantizer_ref(x_ref)

    # Every rank runs the distributed quantizer (configured with amax reduction
    # over NCCL_WORLD).
    x_fp8_dist = quantizer_dist(x_local)

    if is_reference_rank:
        torch.testing.assert_close(
            x_fp8_ref._scale_inv, x_fp8_dist._scale_inv, rtol=0.0, atol=0.0
        )


def test_quantizer():
    """Run the distributed quantizer check for every (input dtype, FP8 dtype) pair.

    Only the fp8_cs scheme is exercised, since its quantizer performs amax
    reduction. For any other QUANTIZATION value this function does nothing.
    """
    if QUANTIZATION != "fp8_cs":
        return

    configs = [
        (input_dtype, fp8_dtype)
        for input_dtype in (torch.float32, torch.bfloat16)
        for fp8_dtype in (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2)
    ]
    for input_dtype, fp8_dtype in configs:
        _test_quantizer(input_dtype, fp8_dtype)


############################################
#                   Linear                 #
############################################
@run_distributed_test()
def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
    """Compare a tensor-parallel ``te.Linear`` against a single-GPU reference.

    Args:
        parallel_mode (str): 'row' or 'column' parallelism.
        sequence_parallel (bool): Enable sequence parallelism if True.
        kwargs (dict): Additional arguments for both linear layers. The flags
            ``return_bias``, ``delay_wgrad_compute`` and
            ``fuse_wgrad_accumulation`` are honoured by value.

    Raises:
        ValueError: If ``parallel_mode`` is neither 'row' nor 'column'.
    """
    params_dtype = kwargs.get("params_dtype", torch.float32)
    return_bias = kwargs.get("return_bias", False)
    delay_wgrad_compute = kwargs.get("delay_wgrad_compute", False)
    fuse_wgrad_accumulation = kwargs.get("fuse_wgrad_accumulation", False)

    # Create models
    model_single_node = te.Linear(HIDDEN_SIZE, HIDDEN_SIZE, **kwargs)
    model_distributed = te.Linear(
        HIDDEN_SIZE,
        HIDDEN_SIZE,
        tp_size=WORLD_SIZE,
        tp_group=NCCL_WORLD,
        parallel_mode=parallel_mode,
        sequence_parallel=sequence_parallel,
        **kwargs,
    )

    # Synchronize parameters between models
    _copy_params(model_distributed, model_single_node)

    # Prepare input tensors
    input_single_node = torch.randn((BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype)

    if parallel_mode == "row":
        # Split input across GPUs for row parallelism
        split_size = HIDDEN_SIZE // WORLD_SIZE
        input_distributed = input_single_node[
            :, WORLD_RANK * split_size : (WORLD_RANK + 1) * split_size
        ].clone()
    elif parallel_mode == "column":
        if sequence_parallel:
            # Each rank draws its own sequence shard; the reference input is
            # their concatenation.
            input_distributed = torch.randn((BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype)
            # For fp8_cs, place the single largest value on the last rank to
            # exercise amax reduction.
            if QUANTIZATION == "fp8_cs":
                input_distributed = torch.clamp(input_distributed, min=-10, max=10)
                if WORLD_RANK == WORLD_SIZE - 1:
                    input_distributed[BATCH_SIZE - 1, HIDDEN_SIZE - 1] = 11
            input_single_node = _gather(input_distributed, dim=0).detach()
        else:
            input_distributed = input_single_node.clone()
    else:
        raise ValueError(f"Invalid parallel_mode: {parallel_mode}")

    # Apply models
    output_single_node, output_distributed = _apply_models(
        model_single_node, model_distributed, input_single_node, input_distributed
    )

    if return_bias:
        output_single_node, bias_s = output_single_node
        output_distributed, bias_d = output_distributed
        if parallel_mode == "column":
            bias_d = _gather(bias_d)
        _check_outputs(bias_s, bias_d)

    # Gather outputs if necessary
    if parallel_mode == "column" or (sequence_parallel and parallel_mode == "row"):
        output_distributed = _gather(output_distributed, dim=1 if parallel_mode == "column" else 0)

    # Compute loss and backpropagate
    _loss_backward(output_single_node, output_distributed)

    # Compute delayed weight gradient
    if delay_wgrad_compute:
        _loss_backward_dw(model_single_node, model_distributed)

    # Validate outputs and gradients
    _check_outputs(output_single_node, output_distributed)

    # Gradients in the remaining cases need additional synchronization.
    if (parallel_mode == "column" or not sequence_parallel) and not return_bias:
        _check_gradients(
            model_distributed,
            model_single_node,
            main_grad_check=bool(fuse_wgrad_accumulation),
        )


def test_linear():
    """Run linear layer tests for every kwargs / parallel mode / sequence-parallel combination."""
    kwargs_list = [
        {},
        {"bias": False},
        {"init_method": _constant},
        {"fuse_wgrad_accumulation": True},
        {"return_bias": True},
        {"params_dtype": torch.float16},
        {"delay_wgrad_compute": True},
    ]
    for kwargs in kwargs_list:
        for parallel_mode in ["column", "row"]:
            for sequence_parallel in [False, True]:
                _test_linear(parallel_mode, sequence_parallel, **kwargs)


############################################
#                 LayerNorm                #
############################################


@run_distributed_test()
def _test_layernorm(kwargs):
    """Compare a norm layer built with distributed options against a plain one.

    Args:
        kwargs (dict): ``norm`` is the TE norm class, ``basic_args`` are passed
            to both models, and ``distributed_args`` are extra arguments for the
            distributed model only.
    """
    norm_cls = kwargs["norm"]
    shared_args = kwargs["basic_args"]
    extra_args = kwargs["distributed_args"]
    dtype = shared_args.get("params_dtype", torch.float32)

    reference = norm_cls(HIDDEN_SIZE, **shared_args)
    candidate = norm_cls(HIDDEN_SIZE, **{**shared_args, **extra_args})
    _copy_params(candidate, reference)

    # Every rank feeds the same full input to both models.
    x_ref = torch.randn((BATCH_SIZE, HIDDEN_SIZE), dtype=dtype).cuda()
    x_dist = x_ref.clone()

    y_ref, y_dist = _apply_models(reference, candidate, x_ref, x_dist)
    _loss_backward(y_ref, y_dist)

    _check_outputs(y_ref, y_dist)
    _check_gradients(candidate, reference)


def test_layernorm():
    """Run LayerNorm and RMSNorm tests over every basic/distributed argument combination."""
    norm_classes = (te.LayerNorm, te.RMSNorm)
    basic_variants = (
        {"zero_centered_gamma": True},
        {"params_dtype": torch.float16},
    )
    distributed_variants = (
        {},
        {"sequence_parallel": True},
    )

    configs = [
        {"norm": norm, "basic_args": basic, "distributed_args": distributed}
        for norm in norm_classes
        for basic in basic_variants
        for distributed in distributed_variants
    ]
    for config in configs:
        _test_layernorm(config)


############################################
#              LayerNormLinear             #
############################################


@run_distributed_test()
def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
    """Compare a tensor-parallel ``te.LayerNormLinear`` against a single-GPU reference.

    Args:
        parallel_mode (str): 'row' or 'column' parallelism.
        sequence_parallel (bool): Enable sequence parallelism if True. Each rank
            then draws its own ``SEQ_LEN`` rows, and the reference input is
            their concatenation.
        kwargs (dict): Additional arguments for both layers. The flags
            ``return_layernorm_output``, ``return_bias``,
            ``delay_wgrad_compute`` and ``fuse_wgrad_accumulation`` are
            honoured by value.
    """
    params_dtype = kwargs.get("params_dtype", torch.float32)
    return_layernorm_output = kwargs.get("return_layernorm_output", False)
    return_bias = kwargs.get("return_bias", False)
    delay_wgrad_compute = kwargs.get("delay_wgrad_compute", False)
    fuse_wgrad_accumulation = kwargs.get("fuse_wgrad_accumulation", False)

    # Create models
    model_single_node = te.LayerNormLinear(HIDDEN_SIZE, HIDDEN_SIZE, **kwargs)
    model_distributed = te.LayerNormLinear(
        HIDDEN_SIZE,
        HIDDEN_SIZE,
        tp_size=WORLD_SIZE,
        tp_group=NCCL_WORLD,
        parallel_mode=parallel_mode,
        sequence_parallel=sequence_parallel,
        **kwargs,
    )

    # Synchronize parameters between models
    _copy_params(model_distributed, model_single_node)

    # Prepare input tensors
    input_single_node = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)

    if sequence_parallel:
        input_distributed = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)
        # For fp8_cs, place the single largest value on the last rank to
        # exercise amax reduction.
        if QUANTIZATION == "fp8_cs":
            input_distributed = torch.clamp(input_distributed, min=-10, max=10)
            if WORLD_RANK == WORLD_SIZE - 1:
                input_distributed[SEQ_LEN - 1, HIDDEN_SIZE - 1] = 11
        input_single_node = _gather(input_distributed).detach()
    else:
        input_distributed = input_single_node.clone()

    # Apply models
    output_single_node, output_distributed = _apply_models(
        model_single_node, model_distributed, input_single_node, input_distributed
    )

    if return_layernorm_output:
        output_single_node, norm_s = output_single_node
        output_distributed, norm_d = output_distributed
        if sequence_parallel:
            norm_d = _gather(norm_d)
        _check_outputs(norm_s, norm_d)

    if return_bias:
        output_single_node, bias_s = output_single_node
        output_distributed, bias_d = output_distributed
        if parallel_mode == "column":
            bias_d = _gather(bias_d)
        _check_outputs(bias_s, bias_d)

    # Gather outputs if necessary
    if parallel_mode == "column" or (sequence_parallel and parallel_mode == "row"):
        output_distributed = _gather(output_distributed, dim=1 if parallel_mode == "column" else 0)

    # Compute loss and backpropagate
    _loss_backward(output_single_node, output_distributed)

    # Compute delayed weight gradient
    if delay_wgrad_compute:
        _loss_backward_dw(model_single_node, model_distributed)

    # Validate outputs and gradients
    _check_outputs(output_single_node, output_distributed)

    # Gradients in the remaining cases need additional synchronization.
    if parallel_mode == "column" and not sequence_parallel and not return_bias:
        _check_gradients(
            model_distributed,
            model_single_node,
            main_grad_check=bool(fuse_wgrad_accumulation),
        )


def test_layernorm_linear():
    """Run LayerNormLinear tests (column parallel, with and without sequence parallelism)."""
    kwargs_list = [
        {},
        {"bias": False},
        {"init_method": _constant},
        {"fuse_wgrad_accumulation": True},
        {"return_bias": True},
        {"params_dtype": torch.float16},
        {"zero_centered_gamma": False},
        {"return_layernorm_output": True},
        {"delay_wgrad_compute": True},
    ]
    for kwargs in kwargs_list:
        for parallel_mode in ["column"]:
            for sequence_parallel in [False, True]:
                _test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs)


############################################
#               LayerNormMLP               #
############################################


@run_distributed_test()
def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwargs):
    """Compare a tensor-parallel ``te.LayerNormMLP`` against a single-GPU reference.

    Args:
        set_parallel_mode (bool): Enable parallel mode.
        sequence_parallel (bool): Enable sequence parallelism if True.
        kwargs (dict): Additional arguments for both MLPs. The flags
            ``return_layernorm_output``, ``return_bias``,
            ``delay_wgrad_compute`` and ``fuse_wgrad_accumulation`` are
            honoured by value.
    """
    params_dtype = kwargs.get("params_dtype", torch.float32)
    return_layernorm_output = kwargs.get("return_layernorm_output", False)
    return_bias = kwargs.get("return_bias", False)
    delay_wgrad_compute = kwargs.get("delay_wgrad_compute", False)
    fuse_wgrad_accumulation = kwargs.get("fuse_wgrad_accumulation", False)
    ffn_hidden_size = 32 if QUANTIZATION is None else 128

    # Create models
    model_single_node = te.LayerNormMLP(HIDDEN_SIZE, ffn_hidden_size, **kwargs)
    model_distributed = te.LayerNormMLP(
        HIDDEN_SIZE,
        ffn_hidden_size,
        tp_size=WORLD_SIZE,
        tp_group=NCCL_WORLD,
        set_parallel_mode=set_parallel_mode,
        sequence_parallel=sequence_parallel,
        **kwargs,
    )

    # Synchronize parameters between models
    _copy_params(model_distributed, model_single_node)

    # Prepare input tensors
    input_single_node = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)

    if sequence_parallel:
        input_distributed = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)
        # For fp8_cs, place the single largest value on the last rank to
        # exercise amax reduction.
        if QUANTIZATION == "fp8_cs":
            input_distributed = torch.clamp(input_distributed, min=-10, max=10)
            if WORLD_RANK == WORLD_SIZE - 1:
                input_distributed[SEQ_LEN - 1, HIDDEN_SIZE - 1] = 11
        input_single_node = _gather(input_distributed).detach()
    else:
        input_distributed = input_single_node.clone()

    # Apply models
    output_single_node, output_distributed = _apply_models(
        model_single_node, model_distributed, input_single_node, input_distributed
    )

    if return_layernorm_output:
        output_single_node, norm_s = output_single_node
        output_distributed, norm_d = output_distributed
        if sequence_parallel:
            norm_d = _gather(norm_d)
        _check_outputs(norm_s, norm_d)

    if return_bias:
        output_single_node, bias_s = output_single_node
        output_distributed, bias_d = output_distributed
        _check_outputs(bias_s, bias_d)

    if sequence_parallel:
        output_distributed = _gather(output_distributed)

    # Compute loss and backpropagate
    _loss_backward(output_single_node, output_distributed)

    # Compute delayed weight gradient
    if delay_wgrad_compute:
        _loss_backward_dw(model_single_node, model_distributed)

    # Validate outputs and gradients
    _check_outputs(output_single_node, output_distributed)

    # Gradients in the remaining cases need additional synchronization.
    if not sequence_parallel and not return_bias:
        _check_gradients(
            model_distributed,
            model_single_node,
            main_grad_check=bool(fuse_wgrad_accumulation),
        )


def test_layernorm_mlp():
    """Run LayerNormMLP tests with and without sequence parallelism."""
    kwargs_list = [
        {},
        {"init_method": _constant},
        {"output_layer_init_method": _constant},
        {"normalization": "RMSNorm"},
        {"zero_centered_gamma": True},
        {"bias": False},
        {"params_dtype": torch.float16},
        {"activation": "relu"},
        {"fuse_wgrad_accumulation": True},
        {"return_bias": True},
        {"return_layernorm_output": True},
        {"delay_wgrad_compute": True},
    ]

    for kwargs in kwargs_list:
        for set_parallel_mode in [True]:
            for sequence_parallel in [False, True]:
                _test_layernorm_mlp(set_parallel_mode, sequence_parallel, **kwargs)


############################################
#             TransformerLayer             #
############################################


@run_distributed_test()
def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs):
    """Compare a tensor-parallel ``te.TransformerLayer`` against a single-GPU reference.

    Args:
        sequence_parallel (bool): Enable sequence parallelism. Each rank then
            receives its own ``SEQ_LEN`` slice of the ``WORLD_SIZE * SEQ_LEN``
            sequence.
        kwargs (dict): Additional arguments for both layers.
    """
    params_dtype = kwargs.get("params_dtype", torch.float32)
    ffn_hidden_size = 32 if QUANTIZATION is None else 128

    model_single_node = te.TransformerLayer(
        HIDDEN_SIZE, ffn_hidden_size, NR_HEADS, attention_dropout=0, hidden_dropout=0, **kwargs
    )
    model_distributed = te.TransformerLayer(
        HIDDEN_SIZE,
        ffn_hidden_size,
        NR_HEADS,
        tp_size=WORLD_SIZE,
        tp_group=NCCL_WORLD,
        set_parallel_mode=True,
        sequence_parallel=sequence_parallel,
        seq_length=WORLD_SIZE * SEQ_LEN if sequence_parallel else None,
        attention_dropout=0,
        hidden_dropout=0,
        **kwargs,
    )

    # main_grad buffers are allocated by _apply_models.
    _copy_params(model_distributed, model_single_node)

    input_single_node = (
        torch.randn((WORLD_SIZE * SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype)
    )
    if sequence_parallel:
        # Clone so the distributed input is an independent tensor rather than a
        # view of the reference input; otherwise both backward passes would
        # accumulate into the same storage.
        input_distributed = input_single_node[
            WORLD_RANK * SEQ_LEN : (WORLD_RANK + 1) * SEQ_LEN, :, :
        ].clone()
    else:
        input_distributed = input_single_node.clone()

    encoder_output = None
    if "layer_type" in kwargs:
        encoder_output = torch.randn((SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE)).cuda()

    output_single_node, output_distributed = _apply_models(
        model_single_node,
        model_distributed,
        input_single_node,
        input_distributed,
        encoder_output=encoder_output,
    )

    if sequence_parallel:
        output_distributed = _gather(output_distributed)

    _loss_backward(output_single_node, output_distributed)
    _check_outputs(output_single_node, output_distributed)

    # Gradients in the remaining cases need additional synchronization.
    if not sequence_parallel and not kwargs.get("return_bias", False):
        _check_gradients(
            model_distributed,
            model_single_node,
            main_grad_check=bool(kwargs.get("fuse_wgrad_accumulation", False)),
        )


def test_transformer_layer():
    """Run TransformerLayer tests with and without sequence parallelism."""
    kwargs_list = [
        {},
        {"num_gqa_groups": 4},
        {"init_method": _constant},
        {"output_layer_init_method": _constant},
        {"apply_residual_connection_post_layernorm": True},
        {"output_layernorm": True},
        {"parallel_attention_mlp": True},
        # {"layer_type": "decoder"},
        {"window_size": (2, 2)},
        {"normalization": "RMSNorm"},
        {"zero_centered_gamma": True},
        {"fuse_qkv_params": True},
        {"fuse_qkv_params": True, "fuse_wgrad_accumulation": True},
        {"qkv_weight_interleaved": False},
        {"bias": False},
        {"params_dtype": torch.float16},
        {"activation": "relu"},
    ]

    for kwargs in kwargs_list:
        for sequence_parallel in [False, True]:
            _test_transformer_layer_parallel(sequence_parallel, **kwargs)


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())