"transformer_engine/pytorch/csrc/common.cpp" did not exist on "a5ba71f3f7379acad9c2292a289aa58ab8a489a8"
test_fusible_ops.py 34.7 KB
Newer Older
1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
#
# See LICENSE for license information.

from __future__ import annotations

import argparse
8
from collections.abc import Iterable
9
10
11
12
13
14
import functools
import itertools
import os
import pathlib
import subprocess
import sys
15
from typing import Optional
16
17
18
19
20

import pytest
import torch

import transformer_engine
21
import transformer_engine.common.recipe
22
import transformer_engine.pytorch as te
23
24
from transformer_engine.pytorch import (
    QuantizedTensor,
25
26
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
27
28
29
    MXFP8Quantizer,
    NVFP4Quantizer,
    is_bf16_available,
30
)
31
32
import transformer_engine.pytorch.ops as te_ops
import transformer_engine_torch as tex
yuguo's avatar
yuguo committed
33
from torch.utils.cpp_extension import IS_HIP_EXTENSION
34

35
36
37
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
38
from utils import dtype_tols, make_recipe, quantization_tols
39

40
# Check what quantization schemes are supported on this system. Each
# probe returns an (available, reason) pair; the reason is meant for
# skip messages.
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
# NVFP4 support must be probed with its own query; reusing the MXFP8
# probe would enable NVFP4 tests (and report the wrong skip reason)
# based on MXFP8 support.
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)

# Quantization schemes to exercise; None means unquantized (high precision)
quantization_list: list[Optional[str]] = [None]
if fp8_available:
    quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
    quantization_list.append("mxfp8")
if nvfp4_available:
    quantization_list.append("nvfp4")


@functools.cache
def world_group() -> torch.distributed.ProcessGroup:
    """Get NCCL process group, initializing if needed

    Reads ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment, binds
    this process to CUDA device ``LOCAL_RANK`` and, if the default
    process group does not exist yet, initializes it with the NCCL
    backend through a file rendezvous at ``/tmp/rdzv``. The result is
    cached, so this work happens at most once per process.

    Returns:
        The default (world) process group.

    """
    world_size = int(os.environ["WORLD_SIZE"])
    # NOTE(review): LOCAL_RANK doubles as the global rank below, which
    # presumably assumes a single-node launch.
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            "nccl",
            init_method="file:///tmp/rdzv",
            world_size=world_size,
            rank=rank,
        )
    # init_process_group() returns None, so fetch the default group
    # explicitly to honour the annotated return type.
    return torch.distributed.group.WORLD


def reset_rng(seed: int = 1234) -> None:
    """Reseed the CPU and CUDA random number generators.

    Calling this before generating test data makes every call (and every
    process) draw the same random values.

    Args:
        seed: Seed handed to ``torch.manual_seed`` and then to
            ``torch.cuda.manual_seed``.

    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)


@torch.no_grad()
def make_reference_and_test_tensors(
    shape: int | Iterable[int],
    quantization: Optional[str] = None,
    ref_dtype: torch.dtype = torch.float64,
    ref_device: torch.device = "cpu",
    test_dtype: torch.dtype = torch.float32,
    test_device: torch.device = "cuda",
    test_is_quantized: bool = False,
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Construct tensors with the same values

    The reference tensor is intended for use in plain PyTorch
    operations in high precision. The test tensor is intended for use
    in Transformer Engine operations.

    If a quantization scheme is provided, the tensor values are
    quantized so that they are representable.

    Args:
        shape: Shape of both tensors.
        quantization: Quantization scheme: ``None``, ``"fp8"``,
            ``"fp8_delayed_scaling"``, ``"fp8_current_scaling"``,
            ``"mxfp8"`` or ``"nvfp4"``.
        ref_dtype: dtype of the reference tensor.
        ref_device: Device of the reference tensor.
        test_dtype: dtype of the test tensor (before any quantization).
        test_device: Device of the test tensor.
        test_is_quantized: If ``True``, return the test tensor in its
            quantized form; otherwise a quantized test tensor is
            dequantized. Requires a quantization scheme.
        requires_grad: Passed to ``requires_grad_`` on both tensors.

    Returns:
        ``(ref, test)`` tensors holding identical values.

    Raises:
        ValueError: If ``test_is_quantized`` is set without a
            quantization scheme, or if the scheme is not recognized.

    """

    # Random reference tensor (uniform in [0, 1))
    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)

    # Construct test tensor from reference tensor
    test = ref.to(device=test_device, dtype=test_dtype)
    if quantization is None:
        if test_is_quantized:
            raise ValueError("Quantization scheme not provided")
        # Tensor.to() returns the input itself when no conversion is
        # needed; clone so the two tensors never alias each other.
        if test.data_ptr() == ref.data_ptr():
            test = test.clone()
    elif quantization in ("fp8", "fp8_delayed_scaling"):
        # FP8 E4M3 quantizer with a unit scale
        quantizer = Float8Quantizer(
            scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
            amax=torch.zeros(1, dtype=torch.float32, device=test_device),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
        test = quantizer(test)
    elif quantization == "fp8_current_scaling":
        quantizer = Float8CurrentScalingQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            device=test_device,
        )
        test = quantizer(test)
    elif quantization == "mxfp8":
        test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
    elif quantization == "nvfp4":
        test = NVFP4Quantizer(
            with_rht=False,
            with_post_rht_amax=False,
            with_2d_quantization=False,
            stochastic_rounding=False,
            with_random_sign_mask=False,
        )(test)
    else:
        raise ValueError(f"Unsupported quantization scheme ({quantization})")
    if isinstance(test, QuantizedTensor) and not test_is_quantized:
        test = test.dequantize()

    # Make sure reference and test tensors match each other: copy the
    # (possibly quantization-rounded) test values back into the reference
    ref.copy_(test)

    ref.requires_grad_(requires_grad)
    test.requires_grad_(requires_grad)
    return ref, test


def _test_all_reduce(
    *,
    local_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
) -> None:
    """Check ``te_ops.AllReduce`` against a plain PyTorch sum.

    Every rank builds the same global input of shape
    ``(world_size, local_size, local_size)`` from a fixed seed and feeds
    slice ``[rank]`` to the op. The expected output is the sum over the
    leading dimension.

    Args:
        local_size: Size of both dimensions of each rank's input slice.
        dtype: dtype of the test tensors.
        device: Device of the test tensors.
        quantization: Optional quantization scheme for the input and the
            output gradient (kept in quantized form).

    """

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    in_shape = [world_size, local_size, local_size]
    out_shape = [local_size, local_size]

    # Random data (identical on every rank thanks to the fixed seed)
    reset_rng()
    with_quantization = quantization is not None
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=with_quantization,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=with_quantization,
    )

    # Plain PyTorch implementation
    y_ref = x_ref.sum(0)
    y_ref.backward(dy_ref)

    # Convert to distributed tensors
    with torch.no_grad():
        dx_ref = x_ref.grad[rank]
        x_ref = x_ref[rank]
        x_test = x_test[rank].clone()
    x_test.requires_grad_()

    # Implementation with fusible operation
    op = te_ops.AllReduce(process_group=process_group)
    y_test = op(x_test)
    y_test.backward(dy_test)

    # Check results. The reduced sum may round differently from the
    # float64 reference, so dtype tolerances apply; the reference input
    # gradient is dy itself, so the gradient must match exactly.
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **dtype_tols(dtype))
    torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0)


def _test_all_gather(
    *,
    local_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
) -> None:
    """Check ``te_ops.AllGather`` against a plain PyTorch reference.

    Every rank builds the same global input of shape
    ``(world_size, local_size, local_size)`` from a fixed seed and feeds
    slice ``[rank]`` to the op. The reference output for each rank is all
    input slices concatenated along the first dimension.

    Args:
        local_size: Size of both dimensions of each rank's input slice.
        dtype: dtype of the test tensors.
        device: Device of the test tensors.
        quantization: Optional quantization scheme for the input and the
            output gradient (kept in quantized form).

    """

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    in_shape = [world_size, local_size, local_size]
    out_shape = [world_size, world_size * local_size, local_size]

    # Random data (identical on every rank thanks to the fixed seed)
    reset_rng()
    with_quantization = quantization is not None
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=with_quantization,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=with_quantization,
    )

    # Plain PyTorch implementation: every rank receives the full
    # concatenation of all input slices
    y_ref = x_ref.tile((world_size, 1, 1)).reshape(out_shape)
    y_ref.backward(dy_ref)

    # Convert to distributed tensors
    with torch.no_grad():
        dx_ref = x_ref.grad[rank]
        x_ref = x_ref[rank]
        x_test = x_test[rank].clone()
        y_ref = y_ref[rank]
        dy_ref = dy_ref[rank]
        dy_test = dy_test[rank].clone()
    x_test.requires_grad_()

    # Implementation with fusible operation
    op = te_ops.AllGather(process_group=process_group)
    y_test = op(x_test)
    y_test.backward(dy_test)

    # Check results. The forward pass is compared exactly; the input
    # gradient is a sum over ranks and so is checked with dtype tolerances.
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, rtol=0, atol=0)
    torch.testing.assert_close(dx_test, dx_ref, **dtype_tols(dtype))


def _test_reduce_scatter(
    *,
    local_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
) -> None:
    """Check ``te_ops.ReduceScatter`` against a plain PyTorch reference.

    Every rank builds the same global input of shape
    ``(world_size, world_size * local_size, local_size)`` from a fixed
    seed and feeds slice ``[rank]`` to the op. The reference output for
    rank ``r`` is chunk ``r`` of the sum over the leading dimension.

    Args:
        local_size: Size of both dimensions of each rank's output chunk.
        dtype: dtype of the test tensors.
        device: Device of the test tensors.
        quantization: Optional quantization scheme for the input and the
            output gradient (kept in quantized form).

    """

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    in_shape = [world_size, world_size * local_size, local_size]
    out_shape = [world_size, local_size, local_size]

    # Random data (identical on every rank thanks to the fixed seed)
    reset_rng()
    with_quantization = quantization is not None
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=with_quantization,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=with_quantization,
    )

    # Plain PyTorch implementation: sum over ranks, then split into
    # one chunk per rank
    y_ref = x_ref.sum(0).reshape(out_shape)
    y_ref.backward(dy_ref)

    # Convert to distributed tensors
    with torch.no_grad():
        dx_ref = x_ref.grad[rank]
        x_ref = x_ref[rank]
        x_test = x_test[rank].clone()
        y_ref = y_ref[rank]
        dy_ref = dy_ref[rank]
        dy_test = dy_test[rank].clone()
    x_test.requires_grad_()

    # Implementation with fusible operation
    op = te_ops.ReduceScatter(process_group=process_group)
    y_test = op(x_test)
    y_test.backward(dy_test)

    # Check results. The reduced output is checked with dtype tolerances;
    # the input gradient only gathers dy and must match exactly.
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **dtype_tols(dtype))
    torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0)


def _test_basic_linear(
    *,
    local_weight_shape: tuple[int, int] = (32, 32),
    local_batch_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
    quantized_weight: bool = False,
    tensor_parallel_mode: str = "column",
    sequence_parallel: bool = False,
) -> None:
    """Check tensor-parallel ``te_ops.BasicLinear`` against ``torch.nn.functional.linear``.

    The full (unsharded) problem is computed in float64 on every rank;
    each rank then keeps the shard that its local op is expected to see
    or produce.

    Args:
        local_weight_shape: ``(out_features, in_features)`` of each
            rank's weight shard.
        local_batch_size: Batch size per rank (the global batch is scaled
            by the world size when ``sequence_parallel`` is set).
        dtype: dtype of the test tensors and the op's parameters.
        device: Device of the test tensors and the op.
        quantization: Optional quantization scheme for compute.
        quantized_weight: Create the weight in quantized form. Only
            meaningful with a quantization scheme; otherwise the test
            returns early.
        tensor_parallel_mode: ``"column"`` shards output features,
            ``"row"`` shards input features.
        sequence_parallel: Also shard the batch dimension of the
            non-feature-sharded side (input for column, output for row).

    """

    # Skip invalid configurations
    quantized_compute = quantization is not None
    if not quantized_compute and quantized_weight:
        return

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    local_out_features, local_in_features = local_weight_shape
    out_features, in_features = local_out_features, local_in_features
    batch_size = local_batch_size
    if tensor_parallel_mode == "column":
        out_features *= world_size
    elif tensor_parallel_mode == "row":
        in_features *= world_size
    if sequence_parallel:
        batch_size *= world_size
    in_shape = [batch_size, in_features]
    out_shape = [batch_size, out_features]

    # Random data (identical on every rank thanks to the fixed seed)
    reset_rng()
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    w_ref, w_test = make_reference_and_test_tensors(
        (out_features, in_features),
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    # Plain PyTorch implementation
    y_ref = torch.nn.functional.linear(x_ref, w_ref)
    y_ref.backward(dy_ref)

    # Convert to distributed tensors
    with torch.no_grad():
        dw_ref = w_ref.grad
        dx_ref = x_ref.grad
        if tensor_parallel_mode == "column":
            # Shard output features: weight rows and output columns
            local_out_features = out_features // world_size
            local_slice = slice(
                rank * local_out_features,
                (rank + 1) * local_out_features,
            )
            w_ref = w_ref[local_slice, :]
            dw_ref = dw_ref[local_slice, :]
            w_test = w_test[local_slice, :]
            y_ref = y_ref[..., local_slice]
            dy_ref = dy_ref[..., local_slice]
            dy_test = dy_test[..., local_slice].clone()
        elif tensor_parallel_mode == "row":
            # Shard input features: weight columns and input columns
            local_in_features = in_features // world_size
            local_slice = slice(
                rank * local_in_features,
                (rank + 1) * local_in_features,
            )
            w_ref = w_ref[:, local_slice]
            dw_ref = dw_ref[:, local_slice]
            w_test = w_test[:, local_slice]
            x_ref = x_ref[..., local_slice]
            dx_ref = dx_ref[..., local_slice]
            x_test = x_test[..., local_slice].clone()
        if sequence_parallel:
            # Shard the batch dimension of the side that is not
            # feature-sharded: the input for column, the output for row
            local_batch_size = batch_size // world_size
            local_slice = slice(
                rank * local_batch_size,
                (rank + 1) * local_batch_size,
            )
            if tensor_parallel_mode == "column":
                x_ref = x_ref[local_slice, ...]
                dx_ref = dx_ref[local_slice, ...]
                x_test = x_test[local_slice, ...].clone()
            elif tensor_parallel_mode == "row":
                y_ref = y_ref[local_slice, ...]
                dy_ref = dy_ref[local_slice, ...]
                dy_test = dy_test[local_slice, ...].clone()
    x_test.requires_grad_()

    # Implementation with fusible operation
    recipe = make_recipe(quantization)
    with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
        op = te_ops.BasicLinear(
            in_features,
            out_features,
            device=device,
            dtype=dtype,
            tensor_parallel_mode=tensor_parallel_mode,
            tensor_parallel_group=process_group,
            sequence_parallel=sequence_parallel,
        )
    with torch.no_grad():
        op.weight.copy_(w_test)
        del w_test
    with te.autocast(enabled=quantized_compute, recipe=recipe):
        y_test = op(x_test)
    y_test.backward(dy_test)

    # Expected numerical error
    tols = dtype_tols(dtype)
    if dtype == torch.float32:
        tols = dtype_tols(torch.float16)  # TF32 GEMM
    if quantized_compute:
        tols = quantization_tols(quantization)

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **tols)
    torch.testing.assert_close(dx_test, dx_ref, **tols)
    torch.testing.assert_close(dw_test, dw_ref, **tols)


def _test_linear(
    *,
    bias: bool = True,
    local_weight_shape: tuple[int, int] = (32, 32),
    local_batch_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
    quantized_weight: bool = False,
    tensor_parallel_mode: str = "column",
    sequence_parallel: bool = False,
) -> None:
    """Check tensor-parallel ``te_ops.Linear`` (optionally with bias).

    The full problem is computed in float64 on every rank and each rank
    keeps its own shard. In row-parallel mode each rank owns a separate
    bias row; the reference adds the sum of all ranks' bias rows.

    Args:
        bias: Whether the layer has a bias.
        local_weight_shape: ``(out_features, in_features)`` of each
            rank's weight shard.
        local_batch_size: Batch size per rank (the global batch is scaled
            by the world size when ``sequence_parallel`` is set).
        dtype: dtype of the test tensors and the layer's parameters.
        device: Device of the test tensors and the layer.
        quantization: Optional quantization scheme for compute.
        quantized_weight: Create the weight in quantized form. Only
            meaningful with a quantization scheme; otherwise the test
            returns early.
        tensor_parallel_mode: ``"column"`` shards output features,
            ``"row"`` shards input features.
        sequence_parallel: Also shard the batch dimension of the
            non-feature-sharded side (input for column, output for row).

    """

    # Skip invalid configurations
    quantized_compute = quantization is not None
    if not quantized_compute and quantized_weight:
        return

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    local_out_features, local_in_features = local_weight_shape
    out_features, in_features = local_out_features, local_in_features
    batch_size = local_batch_size
    if tensor_parallel_mode == "column":
        out_features *= world_size
    elif tensor_parallel_mode == "row":
        in_features *= world_size
    if sequence_parallel:
        batch_size *= world_size
    in_shape = [batch_size, in_features]
    out_shape = [batch_size, out_features]

    # Random data (identical on every rank thanks to the fixed seed)
    reset_rng()
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    w_ref, w_test = make_reference_and_test_tensors(
        (out_features, in_features),
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    b_ref, b_test = None, None
    if bias:
        if tensor_parallel_mode == "row":
            # One bias row per rank
            bias_shape = [world_size, out_features]
        else:
            bias_shape = [out_features]
        b_ref, b_test = make_reference_and_test_tensors(
            bias_shape,
            test_dtype=dtype,
            test_device=device,
        )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    # Plain PyTorch implementation
    y_ref = torch.nn.functional.linear(x_ref, w_ref)
    if bias:
        if tensor_parallel_mode == "row":
            y_ref += b_ref.sum(dim=0)
        else:
            y_ref += b_ref
    y_ref.backward(dy_ref)

    # Convert to distributed tensors
    with torch.no_grad():
        dw_ref = w_ref.grad
        db_ref = b_ref.grad if bias else None
        dx_ref = x_ref.grad
        if tensor_parallel_mode == "column":
            # Shard output features: weight rows, bias, output columns
            local_out_features = out_features // world_size
            local_slice = slice(
                rank * local_out_features,
                (rank + 1) * local_out_features,
            )
            w_ref = w_ref[local_slice, :]
            dw_ref = dw_ref[local_slice, :]
            w_test = w_test[local_slice, :]
            if bias:
                b_ref = b_ref[local_slice]
                db_ref = db_ref[local_slice]
                b_test = b_test[local_slice]
            y_ref = y_ref[..., local_slice]
            dy_ref = dy_ref[..., local_slice]
            dy_test = dy_test[..., local_slice].clone()
        elif tensor_parallel_mode == "row":
            # Shard input features: weight columns and input columns;
            # each rank keeps its own bias row
            local_in_features = in_features // world_size
            local_slice = slice(
                rank * local_in_features,
                (rank + 1) * local_in_features,
            )
            w_ref = w_ref[:, local_slice]
            dw_ref = dw_ref[:, local_slice]
            w_test = w_test[:, local_slice]
            if bias:
                b_ref = b_ref[rank, :]
                db_ref = db_ref[rank, :]
                b_test = b_test[rank, :]
            x_ref = x_ref[..., local_slice]
            dx_ref = dx_ref[..., local_slice]
            x_test = x_test[..., local_slice].clone()
        if sequence_parallel:
            # Shard the batch dimension of the side that is not
            # feature-sharded: the input for column, the output for row
            local_batch_size = batch_size // world_size
            local_slice = slice(
                rank * local_batch_size,
                (rank + 1) * local_batch_size,
            )
            if tensor_parallel_mode == "column":
                x_ref = x_ref[local_slice, ...]
                dx_ref = dx_ref[local_slice, ...]
                x_test = x_test[local_slice, ...].clone()
            elif tensor_parallel_mode == "row":
                y_ref = y_ref[local_slice, ...]
                dy_ref = dy_ref[local_slice, ...]
                dy_test = dy_test[local_slice, ...].clone()
    x_test.requires_grad_()

    # Implementation with fusible operation
    recipe = make_recipe(quantization)
    with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
        model = te_ops.Sequential(
            te_ops.Linear(
                in_features,
                out_features,
                bias=bias,
                device=device,
                dtype=dtype,
                tensor_parallel_mode=tensor_parallel_mode,
                tensor_parallel_group=process_group,
                sequence_parallel=sequence_parallel,
            ),
        )
    with torch.no_grad():
        model[0].weight.copy_(w_test)
        if bias:
            model[0].bias.copy_(b_test)
        del w_test
        del b_test
    with te.autocast(enabled=quantized_compute, recipe=recipe):
        y_test = model(x_test)
    y_test.backward(dy_test)

    # Expected numerical error
    tols = dtype_tols(dtype)
    if dtype == torch.float32:
        tols = dtype_tols(torch.float16)  # TF32 GEMM
    if quantized_compute:
        tols = quantization_tols(quantization)

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **tols)
    torch.testing.assert_close(dx_test, dx_ref, **tols)
    torch.testing.assert_close(dw_test, dw_ref, **tols)
    if bias:
        db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(db_test, db_ref, **tols)


637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
def _test_mlp(
    *,
    bias: bool = True,
    hidden_size: int = 32,
    local_batch_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
    quantized_weight: bool = False,
    sequence_parallel: bool = False,
) -> None:
    """2-layer MLP

    MLP includes GELU activation in order to test op fusions. Model
    performs warmup steps in order to test inter-step logic.

    The model is GELU -> column-parallel Linear -> GELU -> row-parallel
    Linear -> GELU, with an MLP width of ``hidden_size * world_size``.
    In the row-parallel layer each rank owns a separate bias row; the
    reference adds the sum of all ranks' bias rows.

    Args:
        bias: Whether both linear layers have a bias.
        hidden_size: Input/output feature size.
        local_batch_size: Batch size per rank (the global batch is scaled
            by the world size when ``sequence_parallel`` is set).
        dtype: dtype of the test tensors and the model's parameters.
        device: Device of the test tensors and the model.
        quantization: Optional quantization scheme for compute.
        quantized_weight: Create weights in quantized form. Only
            meaningful with a quantization scheme; otherwise the test
            returns early.
        sequence_parallel: Shard the batch dimension of the model's input
            and output across ranks.

    """

    # Skip invalid configurations
    quantized_compute = quantization is not None
    if not quantized_compute and quantized_weight:
        return

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    mlp_size = hidden_size * world_size
    batch_size = local_batch_size
    if sequence_parallel:
        batch_size *= world_size
    in_shape = (batch_size, hidden_size)

    # Random data (identical on every rank thanks to the fixed seed)
    reset_rng()
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    w1_ref, w1_test = make_reference_and_test_tensors(
        (mlp_size, hidden_size),
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    b1_ref, b1_test = None, None
    w2_ref, w2_test = make_reference_and_test_tensors(
        (hidden_size, mlp_size),
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    b2_ref, b2_test = None, None
    if bias:
        b1_ref, b1_test = make_reference_and_test_tensors(
            (mlp_size,),
            test_dtype=dtype,
            test_device=device,
        )
        # One bias row per rank for the row-parallel layer
        b2_ref, b2_test = make_reference_and_test_tensors(
            (world_size, hidden_size),
            test_dtype=dtype,
            test_device=device,
        )
    dy_ref, dy_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    # Plain PyTorch implementation
    y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
    y_ref = torch.nn.functional.linear(y_ref, w1_ref)
    if bias:
        y_ref += b1_ref
    y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
    y_ref = torch.nn.functional.linear(y_ref, w2_ref)
    if bias:
        y_ref += b2_ref.sum(dim=0)
    y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
    y_ref.backward(dy_ref)

    # Convert to distributed tensors
    with torch.no_grad():
        # Each rank owns one slice of the MLP (intermediate) dimension:
        # rows of w1/b1 and columns of w2
        local_mlp_size = mlp_size // world_size
        local_mlp_slice = slice(rank * local_mlp_size, (rank + 1) * local_mlp_size)
        dx_ref = x_ref.grad
        dw1_ref = w1_ref.grad[local_mlp_slice, :]
        w1_ref = w1_ref[local_mlp_slice, :]
        w1_test = w1_test[local_mlp_slice, :]
        dw2_ref = w2_ref.grad[:, local_mlp_slice]
        w2_ref = w2_ref[:, local_mlp_slice]
        w2_test = w2_test[:, local_mlp_slice]
        if bias:
            db1_ref = b1_ref.grad[local_mlp_slice]
            b1_ref = b1_ref[local_mlp_slice]
            b1_test = b1_test[local_mlp_slice]
            db2_ref = b2_ref.grad[rank, :]
            b2_ref = b2_ref[rank, :]
            b2_test = b2_test[rank, :]
        else:
            db1_ref = None
            db2_ref = None
        if sequence_parallel:
            # Input and output are both sharded along the batch dimension
            local_batch_slice = slice(
                rank * local_batch_size,
                (rank + 1) * local_batch_size,
            )
            x_ref = x_ref[local_batch_slice, ...]
            dx_ref = dx_ref[local_batch_slice, ...]
            x_test = x_test[local_batch_slice, ...].clone()
            y_ref = y_ref[local_batch_slice, ...]
            dy_ref = dy_ref[local_batch_slice, ...]
            dy_test = dy_test[local_batch_slice, ...].clone()
    x_test.requires_grad_()

    # Implementation with fusible operation
    recipe = make_recipe(quantization)
    with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
        model = te_ops.Sequential(
            te_ops.GELU(),
            te_ops.Linear(
                hidden_size,
                mlp_size,
                bias=bias,
                device=device,
                dtype=dtype,
                tensor_parallel_mode="column",
                tensor_parallel_group=process_group,
                sequence_parallel=sequence_parallel,
            ),
            te_ops.GELU(),
            te_ops.Linear(
                mlp_size,
                hidden_size,
                bias=bias,
                device=device,
                dtype=dtype,
                tensor_parallel_mode="row",
                tensor_parallel_group=process_group,
                sequence_parallel=sequence_parallel,
            ),
            te_ops.GELU(),
        )
    with torch.no_grad():
        model[1].weight.copy_(w1_test)
        model[3].weight.copy_(w2_test)
        if bias:
            model[1].bias.copy_(b1_test)
            model[3].bias.copy_(b2_test)
        del w1_test, w2_test, b1_test, b2_test

    # Warmup steps
    for _ in range(3):
        with te.autocast(enabled=quantized_compute, recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)
    # Discard gradients accumulated during warmup
    x_test.grad = None
    model[1].weight.grad = None
    model[3].weight.grad = None
    if bias:
        model[1].bias.grad = None
        model[3].bias.grad = None

    # Forward and backward step
    with te.autocast(enabled=quantized_compute, recipe=recipe):
        y_test = model(x_test)
    y_test.backward(dy_test)

    # Expected numerical error
    tols = dtype_tols(dtype)
    if dtype == torch.float32:
        tols = dtype_tols(torch.float16)  # TF32 GEMM
    if quantized_compute:
        tols = quantization_tols(quantization)

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
    dw2_test = model[3].weight.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **tols)
    torch.testing.assert_close(dx_test, dx_ref, **tols)
    torch.testing.assert_close(dw1_test, dw1_ref, **tols)
    torch.testing.assert_close(dw2_test, dw2_ref, **tols)
    if bias:
        db1_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
        db2_test = model[3].bias.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(db1_test, db1_ref, **tols)
        torch.testing.assert_close(db2_test, db2_ref, **tols)


def _test_fp8_scale_update(
    *,
    amax_history_len: int = 31,
    amax_compute_algo: str = "max",
    margin: float = 2,
    local_weight_shape: tuple[int, int] = (32, 32),
    batch_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    tensor_parallel_mode: str = "column",
) -> None:
    """Check FP8 delayed-scaling scale updates for a tensor-parallel ``BasicLinear``.

    Runs one forward/backward pass of a distributed ``te_ops.BasicLinear``
    under a ``DelayedScaling`` recipe with the HYBRID FP8 format. It then
    compares the scales held by the op's input, weight and grad-output
    quantizers against scales computed from the absmax of the full
    (unsharded) reference tensors.

    Args:
        amax_history_len: Forwarded to ``DelayedScaling.amax_history_len``.
        amax_compute_algo: Forwarded to ``DelayedScaling.amax_compute_algo``.
        margin: Forwarded to ``DelayedScaling.margin``. The expected scale
            is divided by ``2**margin``.
        local_weight_shape: Per-rank ``(out_features, in_features)`` of the
            weight shard.
        batch_size: Number of rows in the input tensor.
        dtype: Dtype of the test tensors and of the op.
        device: Device of the test tensors and of the op.
        tensor_parallel_mode: ``"column"`` scales the global output features
            by the world size; ``"row"`` scales the global input features.

    Raises:
        AssertionError: If any quantizer scale differs from its reference.
    """

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions: global shapes span all ranks along the sharded dim
    local_out_features, local_in_features = local_weight_shape
    out_features, in_features = local_out_features, local_in_features
    if tensor_parallel_mode == "column":
        out_features *= world_size
    elif tensor_parallel_mode == "row":
        in_features *= world_size
    in_shape = [batch_size, in_features]
    out_shape = [batch_size, out_features]

    # Random data
    reset_rng()
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        test_dtype=dtype,
        test_device=device,
    )
    w_ref, w_test = make_reference_and_test_tensors(
        (out_features, in_features),
        test_dtype=dtype,
        test_device=device,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    def ref_scale(ref: torch.Tensor, stage: str) -> torch.Tensor:
        """Expected FP8 scale for ``ref``, returned as a float32 CPU tensor."""
        amax = ref.abs().amax()
        # Largest finite E4M3 (forward) and E5M2 (backward) values; the
        # HYBRID format uses E4M3 forward and E5M2 backward.
        max_val = {
            "forward": 448.0,
            "backward": 57344.0,
        }[stage]
        scale = (max_val / amax) / (2**margin)
        return scale.to(dtype=torch.float32, device="cpu")

    # Expected FP8 scales, computed from the full (unsharded) tensors
    x_scale_ref = ref_scale(x_ref, "forward")
    w_scale_ref = ref_scale(w_ref, "forward")
    dy_scale_ref = ref_scale(dy_ref, "backward")

    # Keep only this rank's shard of the test tensors. The reference tensors
    # are needed only for the expected scales above, so they are not sharded.
    with torch.no_grad():
        if tensor_parallel_mode == "column":
            local_slice = slice(
                rank * local_out_features,
                (rank + 1) * local_out_features,
            )
            w_test = w_test[local_slice, :]
            dy_test = dy_test[..., local_slice].clone()
        elif tensor_parallel_mode == "row":
            local_slice = slice(
                rank * local_in_features,
                (rank + 1) * local_in_features,
            )
            w_test = w_test[:, local_slice]
            x_test = x_test[..., local_slice].clone()
    # The sliced clone above is created under no_grad, so re-enable grad
    x_test.requires_grad_()

    # Initialize fusible operation
    op = te_ops.BasicLinear(
        in_features,
        out_features,
        device=device,
        dtype=dtype,
        tensor_parallel_mode=tensor_parallel_mode,
        tensor_parallel_group=process_group,
    )
    with torch.no_grad():
        op.weight.copy_(w_test)
        del w_test

    # Forward and backward pass
    fp8_format = transformer_engine.common.recipe.Format.HYBRID
    recipe = transformer_engine.common.recipe.DelayedScaling(
        margin=margin,
        fp8_format=fp8_format,
        amax_history_len=amax_history_len,
        amax_compute_algo=amax_compute_algo,
    )
    with te.autocast(recipe=recipe):
        y_test = op(x_test)
    y_test.backward(dy_test)

    # Check results: forward quantizers 0/1 are input/weight, backward 0 is grad output
    x_quantizer = op.get_quantizer("forward", 0)
    w_quantizer = op.get_quantizer("forward", 1)
    dy_quantizer = op.get_quantizer("backward", 0)
    x_scale_test = x_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([])
    w_scale_test = w_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([])
    dy_scale_test = dy_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([])
    torch.testing.assert_close(x_scale_test, x_scale_ref)
    torch.testing.assert_close(w_scale_test, w_scale_ref)
    torch.testing.assert_close(dy_scale_test, dy_scale_ref)


def run_parallel_tests() -> None:
    """Run every distributed fusible-op test in this process.

    This is invoked by ``main`` when ``--parallel`` is given. It runs the
    collective-communication tests, then the basic linear, linear, MLP and
    (if FP8 is available) FP8 scale-update tests. The loops sweep over
    ``quantization_list`` and, where applicable, tensor-parallel mode and
    sequence parallelism. Progress messages are printed from rank 0 only.
    """

    # Distributed process group (initialized on first use)
    rank = torch.distributed.get_rank(world_group())

    def log(message: str) -> None:
        """Print ``message`` on rank 0 only, to avoid one copy per rank."""
        if rank == 0:
            print(message)

    # Compute dtype for the linear and MLP tests; it does not depend on the
    # loop configuration, so query bf16 support once.
    dtype = torch.bfloat16 if is_bf16_available() else torch.float32

    # Collective communication ops
    log("Running _test_all_reduce")
    _test_all_reduce()
    for quantization in quantization_list:
        log(f"Running _test_all_gather with quantization={quantization}")
        _test_all_gather(quantization=quantization)
    log("Running _test_reduce_scatter")
    _test_reduce_scatter()

    # Basic linear op
    for config in itertools.product(
        quantization_list,
        ("column", "row"),
        (False, True),
    ):
        log(f"Running _test_basic_linear with {config=}")
        quantization, tensor_parallel_mode, sequence_parallel = config
        _test_basic_linear(
            quantization=quantization,
            tensor_parallel_mode=tensor_parallel_mode,
            sequence_parallel=sequence_parallel,
        )

    # Linear op
    for config in itertools.product(
        quantization_list,
        ("column", "row"),
        (False, True),
    ):
        log(f"Running _test_linear with {config=}")
        quantization, tensor_parallel_mode, sequence_parallel = config
        _test_linear(
            bias=True,  # bias=False is tested in _test_basic_linear
            dtype=dtype,
            quantization=quantization,
            tensor_parallel_mode=tensor_parallel_mode,
            sequence_parallel=sequence_parallel,
        )

    # MLP
    for config in itertools.product(quantization_list, (False, True)):
        log(f"Running _test_mlp with {config=}")
        quantization, sequence_parallel = config
        _test_mlp(
            bias=True,  # bias=False is tested in _test_basic_linear
            dtype=dtype,
            quantization=quantization,
            sequence_parallel=sequence_parallel,
        )

    # FP8 scale update
    if fp8_available:
        log("Running _test_fp8_scale_update")
        _test_fp8_scale_update()


# Parallel job sizes: one job using the number of visible GPUs, a
# single-process job, and (with at least two GPUs) a two-process job.
# dict.fromkeys drops duplicates while keeping first-seen order.
_world_sizes = list(
    dict.fromkeys(
        [torch.cuda.device_count(), 1]
        + ([2] if torch.cuda.device_count() >= 2 else [])
    )
)


@pytest.mark.parametrize("world_size", _world_sizes)
def test_distributed_fuser_ops(world_size: int) -> None:
    """Re-launch this file under ``torch.distributed.run`` with ``world_size`` processes.

    Each spawned process is started with ``--parallel``. The test fails with
    ``subprocess.CalledProcessError`` if the launcher exits with a non-zero
    status.
    """
    interpreter = pathlib.Path(sys.executable).resolve()
    this_script = pathlib.Path(__file__).resolve()
    launcher = [interpreter, "-m", "torch.distributed.run"]
    launcher_args = [f"--nproc_per_node={world_size}", this_script, "--parallel"]
    subprocess.run(launcher + launcher_args, check=True)


def main() -> None:
    """Command-line entry point.

    With ``--parallel``, run the distributed tests in this process;
    without it, parse the arguments and return.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--parallel",
        action="store_true",
        help="Run parallel tests",
    )
    if not parser.parse_args().parallel:
        return
    run_parallel_tests()


if __name__ == "__main__":
    main()