# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from __future__ import annotations

import argparse
8
from collections.abc import Iterable
9
10
11
12
13
14
import functools
import itertools
import os
import pathlib
import subprocess
import sys
15
from typing import Optional
16
17
18
19
20

import pytest
import torch

import transformer_engine
21
import transformer_engine.common.recipe
22
23
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
24
from transformer_engine.pytorch.tensor import QuantizedTensor
25
26
27
28
29
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
30
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
31
32
33
34
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex

35
36
37
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
38
from utils import dtype_tols, make_recipe, quantization_tols
39

40
41

# Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# NVFP4 support must be queried with its own check; reusing the MXFP8 check
# would enable NVFP4 tests on hardware that only supports MXFP8.
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()

# Quantization schemes exercised by the tests (None = no quantization)
quantization_list: list[Optional[str]] = [None]
if fp8_available:
    quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
    quantization_list.append("mxfp8")
if nvfp4_available:
    quantization_list.append("nvfp4")


@functools.cache
def world_group() -> torch.distributed.ProcessGroup:
    """Get NCCL process group, initializing if needed

    On first call, sets the current CUDA device to ``LOCAL_RANK`` and
    initializes the default NCCL process group from the ``WORLD_SIZE``
    and ``LOCAL_RANK`` environment variables, using a file-based
    rendezvous at ``/tmp/rdzv``. ``functools.cache`` makes subsequent
    calls return the same group without re-initializing.

    Returns:
        The default (world) process group.

    """
    world_size = int(os.environ["WORLD_SIZE"])
    # NOTE(review): LOCAL_RANK doubles as the global rank, which assumes a
    # single-node launch — confirm if multi-node runs are ever needed.
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)
    # init_process_group sets up the default group but returns None, so
    # fetch the group handle explicitly to honor the return annotation.
    torch.distributed.init_process_group(
        "nccl",
        init_method="file:///tmp/rdzv",
        world_size=world_size,
        rank=rank,
    )
    return torch.distributed.group.WORLD


def reset_rng(seed: int = 1234) -> None:
    """Seed the CPU and CUDA random number generators

    Both generators receive the same ``seed`` so that random tensors
    drawn afterwards are reproducible.

    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)


@torch.no_grad()
def make_reference_and_test_tensors(
    shape: int | Iterable[int],
78
    quantization: Optional[str] = None,
79
80
81
82
    ref_dtype: torch.dtype = torch.float64,
    ref_device: torch.device = "cpu",
    test_dtype: torch.dtype = torch.float32,
    test_device: torch.device = "cuda",
83
    test_is_quantized: bool = False,
84
85
86
87
88
89
90
91
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Construct tensors with the same values

    The reference tensor is intended for use in plain PyTorch
    operations in high precision. The test tensor is intended for use
    in Transformer Engine operations.

92
93
94
    If a quantization scheme is provided, the tensor values are
    quantized so that they are representable.

95
    """
96
97

    # Random reference tensor
98
    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
99
100

    # Construct test tensor from reference tensor
101
    test = ref.to(device=test_device, dtype=test_dtype)
102
103
104
105
106
107
    if quantization is None:
        if test_is_quantized:
            raise ValueError("Quantization scheme not provided")
        if test.data_ptr() == ref.data_ptr():
            test = test.clone()
    elif quantization in ("fp8", "fp8_delayed_scaling"):
108
        quantizer = Float8Quantizer(
109
            scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
110
111
112
113
            amax=torch.zeros(1, dtype=torch.float32, device=test_device),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
        test = quantizer(test)
114
115
116
117
118
119
120
121
    elif quantization == "fp8_current_scaling":
        quantizer = Float8CurrentScalingQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            device=test_device,
        )
        test = quantizer(test)
    elif quantization == "mxfp8":
        test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
122
123
124
125
126
127
128
129
    elif quantization == "nvfp4":
        test = NVFP4Quantizer(
            with_rht=False,
            with_post_rht_amax=False,
            with_2d_quantization=False,
            stochastic_rounding=False,
            with_random_sign_mask=False,
        )(test)
130
131
132
133
134
135
    else:
        raise ValueError(f"Unsupported quantization scheme ({quantization})")
    if isinstance(test, QuantizedTensor) and not test_is_quantized:
        test = test.dequantize()

    # Make sure reference and test tensors match each other
136
    ref.copy_(test)
137

138
139
140
141
142
143
144
    ref.requires_grad_(requires_grad)
    test.requires_grad_(requires_grad)
    return ref, test


def _test_all_reduce(
    *,
    local_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
) -> None:
    """Check the all-reduce fusible op against a plain PyTorch sum

    A ``(world_size, local_size, local_size)`` reference tensor is
    summed over its leading dimension; each rank feeds its own slice
    into ``te_ops.AllReduce``. The forward output is compared with
    dtype-dependent tolerances and the input gradient (which equals the
    output gradient for a sum) is compared exactly.

    """

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    in_shape = [world_size, local_size, local_size]
    out_shape = [local_size, local_size]

    # Random data
    reset_rng()
    with_quantization = quantization is not None
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=with_quantization,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=with_quantization,
    )

    # Plain PyTorch implementation
    y_ref = x_ref.sum(0)
    y_ref.backward(dy_ref)

    # Convert to distributed tensors (each rank keeps its own slice)
    with torch.no_grad():
        dx_ref = x_ref.grad[rank]
        x_ref = x_ref[rank]
        x_test = x_test[rank].clone()
    x_test.requires_grad_()

    # Implementation with fusible operation
    op = te_ops.AllReduce(process_group=process_group)
    y_test = op(x_test)
    y_test.backward(dy_test)

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **dtype_tols(dtype))
    torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0)


def _test_all_gather(
    *,
    local_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
) -> None:
    """Check the all-gather fusible op against a plain PyTorch reference

    Each rank contributes one ``(local_size, local_size)`` slice of a
    ``(world_size, local_size, local_size)`` reference tensor. The
    reference output for every rank is all slices concatenated along
    the first dimension (built with ``tile`` + ``reshape``), so the
    reference input gradient is the sum over ranks of the matching
    chunk of each rank's output gradient. The forward output must match
    exactly; the gradient is compared with dtype-dependent tolerances.

    """

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    in_shape = [world_size, local_size, local_size]
    out_shape = [world_size, world_size * local_size, local_size]

    # Random data
    reset_rng()
    with_quantization = quantization is not None
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=with_quantization,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=with_quantization,
    )

    # Plain PyTorch implementation: every rank sees all slices
    y_ref = x_ref.tile((world_size, 1, 1)).reshape(out_shape)
    y_ref.backward(dy_ref)

    # Convert to distributed tensors (each rank keeps its own slice)
    with torch.no_grad():
        dx_ref = x_ref.grad[rank]
        x_ref = x_ref[rank]
        x_test = x_test[rank].clone()
        y_ref = y_ref[rank]
        dy_ref = dy_ref[rank]
        dy_test = dy_test[rank].clone()
    x_test.requires_grad_()

    # Implementation with fusible operation
    op = te_ops.AllGather(process_group=process_group)
    y_test = op(x_test)
    y_test.backward(dy_test)

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, rtol=0, atol=0)
    torch.testing.assert_close(dx_test, dx_ref, **dtype_tols(dtype))


def _test_reduce_scatter(
    *,
    local_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
) -> None:
    """Check the reduce-scatter fusible op against a plain PyTorch reference

    Each rank contributes one ``(world_size * local_size, local_size)``
    slice of the reference input. The reference output is the sum over
    ranks, split into ``world_size`` chunks with each rank keeping its
    own chunk. The reference input gradient is therefore every rank's
    output gradient stacked together. The forward output is compared
    with dtype-dependent tolerances; the gradient must match exactly.

    """

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    in_shape = [world_size, world_size * local_size, local_size]
    out_shape = [world_size, local_size, local_size]

    # Random data
    reset_rng()
    with_quantization = quantization is not None
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=with_quantization,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=with_quantization,
    )

    # Plain PyTorch implementation: sum over ranks, then split into chunks
    y_ref = x_ref.sum(0).reshape(out_shape)
    y_ref.backward(dy_ref)

    # Convert to distributed tensors (each rank keeps its own slice)
    with torch.no_grad():
        dx_ref = x_ref.grad[rank]
        x_ref = x_ref[rank]
        x_test = x_test[rank].clone()
        y_ref = y_ref[rank]
        dy_ref = dy_ref[rank]
        dy_test = dy_test[rank].clone()
    x_test.requires_grad_()

    # Implementation with fusible operation
    op = te_ops.ReduceScatter(process_group=process_group)
    y_test = op(x_test)
    y_test.backward(dy_test)

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **dtype_tols(dtype))
    torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0)


def _test_basic_linear(
    *,
    local_weight_shape: tuple[int, int] = (32, 32),
    local_batch_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
    quantized_weight: bool = False,
    tensor_parallel_mode: str = "column",
    sequence_parallel: bool = False,
) -> None:
    """Check tensor-parallel ``te_ops.BasicLinear`` against a PyTorch GEMM

    The global weight is ``local_weight_shape`` scaled by ``world_size``
    along output features (``"column"``) or input features (``"row"``).
    With ``sequence_parallel``, the batch is also split across ranks:
    the input is batch-sharded in column mode and the output is
    batch-sharded in row mode. Requesting quantized weights without a
    quantization scheme is an invalid configuration and returns early.

    """

    # Skip invalid configurations
    quantized_compute = quantization is not None
    if not quantized_compute and quantized_weight:
        return

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    local_out_features, local_in_features = local_weight_shape
    out_features, in_features = local_out_features, local_in_features
    batch_size = local_batch_size
    if tensor_parallel_mode == "column":
        out_features *= world_size
    elif tensor_parallel_mode == "row":
        in_features *= world_size
    if sequence_parallel:
        batch_size *= world_size
    in_shape = [batch_size, in_features]
    out_shape = [batch_size, out_features]

    # Random data
    reset_rng()
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    w_ref, w_test = make_reference_and_test_tensors(
        (out_features, in_features),
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    # Plain PyTorch implementation
    y_ref = torch.nn.functional.linear(x_ref, w_ref)
    y_ref.backward(dy_ref)

    # Convert to distributed tensors
    with torch.no_grad():
        dw_ref = w_ref.grad
        dx_ref = x_ref.grad
        if tensor_parallel_mode == "column":
            # Each rank owns a contiguous block of output features
            local_out_features = out_features // world_size
            local_slice = slice(
                rank * local_out_features,
                (rank + 1) * local_out_features,
            )
            w_ref = w_ref[local_slice, :]
            dw_ref = dw_ref[local_slice, :]
            w_test = w_test[local_slice, :]
            y_ref = y_ref[..., local_slice]
            dy_ref = dy_ref[..., local_slice]
            dy_test = dy_test[..., local_slice].clone()
        elif tensor_parallel_mode == "row":
            # Each rank owns a contiguous block of input features
            local_in_features = in_features // world_size
            local_slice = slice(
                rank * local_in_features,
                (rank + 1) * local_in_features,
            )
            w_ref = w_ref[:, local_slice]
            dw_ref = dw_ref[:, local_slice]
            w_test = w_test[:, local_slice]
            x_ref = x_ref[..., local_slice]
            dx_ref = dx_ref[..., local_slice]
            x_test = x_test[..., local_slice].clone()
        if sequence_parallel:
            # Shard the batch of whichever activation is not feature-sharded
            local_batch_size = batch_size // world_size
            local_slice = slice(
                rank * local_batch_size,
                (rank + 1) * local_batch_size,
            )
            if tensor_parallel_mode == "column":
                x_ref = x_ref[local_slice, ...]
                dx_ref = dx_ref[local_slice, ...]
                x_test = x_test[local_slice, ...].clone()
            elif tensor_parallel_mode == "row":
                y_ref = y_ref[local_slice, ...]
                dy_ref = dy_ref[local_slice, ...]
                dy_test = dy_test[local_slice, ...].clone()
    x_test.requires_grad_()

    # Implementation with fusible operation
    recipe = make_recipe(quantization)
    with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
        op = te_ops.BasicLinear(
            in_features,
            out_features,
            device=device,
            dtype=dtype,
            tensor_parallel_mode=tensor_parallel_mode,
            tensor_parallel_group=process_group,
            sequence_parallel=sequence_parallel,
        )
    with torch.no_grad():
        op.weight.copy_(w_test)
        del w_test
    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
        y_test = op(x_test)
    y_test.backward(dy_test)

    # Expected numerical error
    tols = dtype_tols(dtype)
    if dtype == torch.float32:
        tols = dtype_tols(torch.float16)  # TF32 GEMM
    if quantized_compute:
        tols = quantization_tols(quantization)

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **tols)
    torch.testing.assert_close(dx_test, dx_ref, **tols)
    torch.testing.assert_close(dw_test, dw_ref, **tols)


def _test_linear(
    *,
    bias: bool = True,
    local_weight_shape: tuple[int, int] = (32, 32),
    local_batch_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
    quantized_weight: bool = False,
    tensor_parallel_mode: str = "column",
    sequence_parallel: bool = False,
) -> None:
    """Check tensor-parallel ``te_ops.Linear`` (with optional bias)

    Same sharding scheme as the basic linear test: the global weight is
    scaled by ``world_size`` along output features (``"column"``) or
    input features (``"row"``), and ``sequence_parallel`` additionally
    splits the batch. In column mode the bias is sliced like the output
    features. In row mode every rank holds its own full-width bias row
    (reference bias shape ``(world_size, out_features)``), and the
    reference output adds the sum of those rows. Requesting quantized
    weights without a quantization scheme returns early.

    """

    # Skip invalid configurations
    quantized_compute = quantization is not None
    if not quantized_compute and quantized_weight:
        return

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    local_out_features, local_in_features = local_weight_shape
    out_features, in_features = local_out_features, local_in_features
    batch_size = local_batch_size
    if tensor_parallel_mode == "column":
        out_features *= world_size
    elif tensor_parallel_mode == "row":
        in_features *= world_size
    if sequence_parallel:
        batch_size *= world_size
    in_shape = [batch_size, in_features]
    out_shape = [batch_size, out_features]

    # Random data
    reset_rng()
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    w_ref, w_test = make_reference_and_test_tensors(
        (out_features, in_features),
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    b_ref, b_test = None, None
    if bias:
        if tensor_parallel_mode == "row":
            bias_shape = [world_size, out_features]
        else:
            bias_shape = [out_features]
        b_ref, b_test = make_reference_and_test_tensors(
            bias_shape,
            test_dtype=dtype,
            test_device=device,
        )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    # Plain PyTorch implementation
    y_ref = torch.nn.functional.linear(x_ref, w_ref)
    if bias:
        if tensor_parallel_mode == "row":
            # Per-rank biases contribute to the reduced output
            y_ref += b_ref.sum(dim=0)
        else:
            y_ref += b_ref
    y_ref.backward(dy_ref)

    # Convert to distributed tensors
    with torch.no_grad():
        dw_ref = w_ref.grad
        db_ref = b_ref.grad if bias else None
        dx_ref = x_ref.grad
        if tensor_parallel_mode == "column":
            # Each rank owns a contiguous block of output features
            local_out_features = out_features // world_size
            local_slice = slice(
                rank * local_out_features,
                (rank + 1) * local_out_features,
            )
            w_ref = w_ref[local_slice, :]
            dw_ref = dw_ref[local_slice, :]
            w_test = w_test[local_slice, :]
            if bias:
                b_ref = b_ref[local_slice]
                db_ref = db_ref[local_slice]
                b_test = b_test[local_slice]
            y_ref = y_ref[..., local_slice]
            dy_ref = dy_ref[..., local_slice]
            dy_test = dy_test[..., local_slice].clone()
        elif tensor_parallel_mode == "row":
            # Each rank owns a contiguous block of input features
            local_in_features = in_features // world_size
            local_slice = slice(
                rank * local_in_features,
                (rank + 1) * local_in_features,
            )
            w_ref = w_ref[:, local_slice]
            dw_ref = dw_ref[:, local_slice]
            w_test = w_test[:, local_slice]
            if bias:
                b_ref = b_ref[rank, :]
                db_ref = db_ref[rank, :]
                b_test = b_test[rank, :]
            x_ref = x_ref[..., local_slice]
            dx_ref = dx_ref[..., local_slice]
            x_test = x_test[..., local_slice].clone()
        if sequence_parallel:
            # Shard the batch of whichever activation is not feature-sharded
            local_batch_size = batch_size // world_size
            local_slice = slice(
                rank * local_batch_size,
                (rank + 1) * local_batch_size,
            )
            if tensor_parallel_mode == "column":
                x_ref = x_ref[local_slice, ...]
                dx_ref = dx_ref[local_slice, ...]
                x_test = x_test[local_slice, ...].clone()
            elif tensor_parallel_mode == "row":
                y_ref = y_ref[local_slice, ...]
                dy_ref = dy_ref[local_slice, ...]
                dy_test = dy_test[local_slice, ...].clone()
    x_test.requires_grad_()

    # Implementation with fusible operation
    recipe = make_recipe(quantization)
    with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
        model = te_ops.Sequential(
            te_ops.Linear(
                in_features,
                out_features,
                bias=bias,
                device=device,
                dtype=dtype,
                tensor_parallel_mode=tensor_parallel_mode,
                tensor_parallel_group=process_group,
                sequence_parallel=sequence_parallel,
            ),
        )
    with torch.no_grad():
        model[0].weight.copy_(w_test)
        if bias:
            model[0].bias.copy_(b_test)
        del w_test
        del b_test
    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
        y_test = model(x_test)
    y_test.backward(dy_test)

    # Expected numerical error
    tols = dtype_tols(dtype)
    if dtype == torch.float32:
        tols = dtype_tols(torch.float16)  # TF32 GEMM
    if quantized_compute:
        tols = quantization_tols(quantization)

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **tols)
    torch.testing.assert_close(dx_test, dx_ref, **tols)
    torch.testing.assert_close(dw_test, dw_ref, **tols)
    if bias:
        db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(db_test, db_ref, **tols)


def _test_mlp(
    *,
    bias: bool = True,
    hidden_size: int = 32,
    local_batch_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
    quantized_weight: bool = False,
    sequence_parallel: bool = False,
) -> None:
    """2-layer MLP

    MLP includes GELU activation in order to test op fusions. Model
    performs warmup steps in order to test inter-step logic.

    The model is GELU -> column-parallel Linear -> GELU -> row-parallel
    Linear -> GELU, with an MLP width of ``hidden_size * world_size``.
    Each rank owns a ``hidden_size``-wide block of the MLP dimension.
    The row-parallel bias is modelled as one ``(hidden_size,)`` row per
    rank, and the reference adds the sum of those rows. With
    ``sequence_parallel`` the batch is split across ranks for both the
    input and the output. Requesting quantized weights without a
    quantization scheme returns early.

    """

    # Skip invalid configurations
    quantized_compute = quantization is not None
    if not quantized_compute and quantized_weight:
        return

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions
    mlp_size = hidden_size * world_size
    batch_size = local_batch_size
    if sequence_parallel:
        batch_size *= world_size
    in_shape = (batch_size, hidden_size)

    # Random data
    # (the order of these calls determines which random values each tensor
    # receives after reset_rng, so keep it stable)
    reset_rng()
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    w1_ref, w1_test = make_reference_and_test_tensors(
        (mlp_size, hidden_size),
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    b1_ref, b1_test = None, None
    w2_ref, w2_test = make_reference_and_test_tensors(
        (hidden_size, mlp_size),
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    b2_ref, b2_test = None, None
    if bias:
        b1_ref, b1_test = make_reference_and_test_tensors(
            (mlp_size,),
            test_dtype=dtype,
            test_device=device,
        )
        # One bias row per rank for the row-parallel layer
        b2_ref, b2_test = make_reference_and_test_tensors(
            (world_size, hidden_size),
            test_dtype=dtype,
            test_device=device,
        )
    dy_ref, dy_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    # Plain PyTorch implementation
    y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
    y_ref = torch.nn.functional.linear(y_ref, w1_ref)
    if bias:
        y_ref += b1_ref
    y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
    y_ref = torch.nn.functional.linear(y_ref, w2_ref)
    if bias:
        # Every rank's row-parallel bias contributes to the reduced output
        y_ref += b2_ref.sum(dim=0)
    y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
    y_ref.backward(dy_ref)

    # Convert to distributed tensors
    with torch.no_grad():
        # Rank-local block of the MLP dimension: output features of the
        # first linear, input features of the second
        local_mlp_size = mlp_size // world_size
        local_mlp_slice = slice(rank * local_mlp_size, (rank + 1) * local_mlp_size)
        dx_ref = x_ref.grad
        dw1_ref = w1_ref.grad[local_mlp_slice, :]
        w1_ref = w1_ref[local_mlp_slice, :]
        w1_test = w1_test[local_mlp_slice, :]
        dw2_ref = w2_ref.grad[:, local_mlp_slice]
        w2_ref = w2_ref[:, local_mlp_slice]
        w2_test = w2_test[:, local_mlp_slice]
        if bias:
            db1_ref = b1_ref.grad[local_mlp_slice]
            b1_ref = b1_ref[local_mlp_slice]
            b1_test = b1_test[local_mlp_slice]
            db2_ref = b2_ref.grad[rank, :]
            b2_ref = b2_ref[rank, :]
            b2_test = b2_test[rank, :]
        else:
            db1_ref = None
            db2_ref = None
        if sequence_parallel:
            # Input and output are both batch-sharded across ranks
            local_batch_slice = slice(
                rank * local_batch_size,
                (rank + 1) * local_batch_size,
            )
            x_ref = x_ref[local_batch_slice, ...]
            dx_ref = dx_ref[local_batch_slice, ...]
            x_test = x_test[local_batch_slice, ...].clone()
            y_ref = y_ref[local_batch_slice, ...]
            dy_ref = dy_ref[local_batch_slice, ...]
            dy_test = dy_test[local_batch_slice, ...].clone()
    x_test.requires_grad_()

    # Implementation with fusible operation
    recipe = make_recipe(quantization)
    with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
        model = te_ops.Sequential(
            te_ops.GELU(),
            te_ops.Linear(
                hidden_size,
                mlp_size,
                bias=bias,
                device=device,
                dtype=dtype,
                tensor_parallel_mode="column",
                tensor_parallel_group=process_group,
                sequence_parallel=sequence_parallel,
            ),
            te_ops.GELU(),
            te_ops.Linear(
                mlp_size,
                hidden_size,
                bias=bias,
                device=device,
                dtype=dtype,
                tensor_parallel_mode="row",
                tensor_parallel_group=process_group,
                sequence_parallel=sequence_parallel,
            ),
            te_ops.GELU(),
        )
    with torch.no_grad():
        model[1].weight.copy_(w1_test)
        model[3].weight.copy_(w2_test)
        if bias:
            model[1].bias.copy_(b1_test)
            model[3].bias.copy_(b2_test)
        del w1_test, w2_test, b1_test, b2_test

    # Warmup steps
    # (several forward/backward passes exercise state carried between
    # steps; gradients accumulated here are cleared before the checked step)
    for _ in range(3):
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)
    x_test.grad = None
    model[1].weight.grad = None
    model[3].weight.grad = None
    if bias:
        model[1].bias.grad = None
        model[3].bias.grad = None

    # Forward and backward step
    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
        y_test = model(x_test)
    y_test.backward(dy_test)

    # Expected numerical error
    tols = dtype_tols(dtype)
    if dtype == torch.float32:
        tols = dtype_tols(torch.float16)  # TF32 GEMM
    if quantized_compute:
        tols = quantization_tols(quantization)

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
    dw2_test = model[3].weight.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **tols)
    torch.testing.assert_close(dx_test, dx_ref, **tols)
    torch.testing.assert_close(dw1_test, dw1_ref, **tols)
    torch.testing.assert_close(dw2_test, dw2_ref, **tols)
    if bias:
        db1_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
        db2_test = model[3].bias.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(db1_test, db1_ref, **tols)
        torch.testing.assert_close(db2_test, db2_ref, **tols)


def _test_fp8_scale_update(
    *,
    amax_history_len: int = 31,
    amax_compute_algo: str = "max",
    margin: float = 2,
    local_weight_shape: tuple[int, int] = (32, 32),
    batch_size: int = 32,
    dtype: torch.dtype = torch.float32,
    device: torch.device | str = "cuda",
    tensor_parallel_mode: str = "column",
) -> None:
    """Check FP8 delayed-scaling scale updates for a tensor-parallel linear op.

    Runs one forward/backward pass of a ``te_ops.BasicLinear`` under a
    ``DelayedScaling`` recipe with the HYBRID FP8 format. It then compares the
    scales held by the input, weight, and grad-output quantizers against
    scales computed from the full (unsharded) reference tensors.

    Args:
        amax_history_len: ``amax_history_len`` passed to the recipe.
        amax_compute_algo: ``amax_compute_algo`` passed to the recipe.
        margin: Recipe margin. The expected scale is divided by ``2**margin``.
        local_weight_shape: Per-rank weight shape ``(out_features, in_features)``.
        batch_size: Number of rows in the input tensor.
        dtype: Data type of the test tensors and the op's parameters.
        device: Device for the test tensors and the op.
        tensor_parallel_mode: ``"column"`` shards the output features across
            ranks, and ``"row"`` shards the input features.
    """

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)
    world_size = torch.distributed.get_world_size(process_group)

    # Tensor dimensions (global sizes; the sharded dimension is scaled by world size)
    local_out_features, local_in_features = local_weight_shape
    out_features, in_features = local_out_features, local_in_features
    if tensor_parallel_mode == "column":
        out_features *= world_size
    elif tensor_parallel_mode == "row":
        in_features *= world_size
    in_shape = [batch_size, in_features]
    out_shape = [batch_size, out_features]

    # Random data
    reset_rng()
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        test_dtype=dtype,
        test_device=device,
    )
    w_ref, w_test = make_reference_and_test_tensors(
        (out_features, in_features),
        test_dtype=dtype,
        test_device=device,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    def ref_scale(ref: torch.Tensor, stage: str) -> torch.Tensor:
        """Expected FP8 scale for a tensor quantized in the given pass"""
        # Largest finite magnitudes of the HYBRID format's FP8 types:
        # E4M3 in the forward pass, E5M2 in the backward pass
        max_val = {
            "forward": 448.0,
            "backward": 57344.0,
        }[stage]
        scale = (max_val / ref.abs().amax()) / (2**margin)
        return scale.to(dtype=torch.float32, device="cpu")

    # Compute expected FP8 scales from the full (unsharded) tensors
    x_scale_ref = ref_scale(x_ref, "forward")
    w_scale_ref = ref_scale(w_ref, "forward")
    dy_scale_ref = ref_scale(dy_ref, "backward")

    # Convert to distributed tensors
    with torch.no_grad():
        if tensor_parallel_mode == "column":
            local_out_features = out_features // world_size
            local_slice = slice(
                rank * local_out_features,
                (rank + 1) * local_out_features,
            )
            w_ref = w_ref[local_slice, :]
            w_test = w_test[local_slice, :]
            dy_ref = dy_ref[..., local_slice]
            dy_test = dy_test[..., local_slice].clone()
        elif tensor_parallel_mode == "row":
            local_in_features = in_features // world_size
            local_slice = slice(
                rank * local_in_features,
                (rank + 1) * local_in_features,
            )
            w_ref = w_ref[:, local_slice]
            w_test = w_test[:, local_slice]
            x_ref = x_ref[..., local_slice]
            x_test = x_test[..., local_slice].clone()
    x_test.requires_grad_()

    # Initialize fusible operation
    op = te_ops.BasicLinear(
        in_features,
        out_features,
        device=device,
        dtype=dtype,
        tensor_parallel_mode=tensor_parallel_mode,
        tensor_parallel_group=process_group,
    )
    with torch.no_grad():
        op.weight.copy_(w_test)
        del w_test

    # Forward and backward pass
    fp8_format = transformer_engine.common.recipe.Format.HYBRID
    recipe = transformer_engine.common.recipe.DelayedScaling(
        margin=margin,
        fp8_format=fp8_format,
        amax_history_len=amax_history_len,
        amax_compute_algo=amax_compute_algo,
    )
    with te.fp8_autocast(fp8_recipe=recipe):
        y_test = op(x_test)
    y_test.backward(dy_test)

    # Check results
    # NOTE(review): the references are computed from the unsharded tensors, so
    # this presumably relies on the op reducing amaxes over the tensor-parallel
    # group before updating scales. Confirm against the DelayedScaling recipe.
    x_quantizer = op.get_quantizer("forward", 0)
    w_quantizer = op.get_quantizer("forward", 1)
    dy_quantizer = op.get_quantizer("backward", 0)
    x_scale_test = x_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([])
    w_scale_test = w_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([])
    dy_scale_test = dy_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([])
    torch.testing.assert_close(x_scale_test, x_scale_ref)
    torch.testing.assert_close(w_scale_test, w_scale_ref)
    torch.testing.assert_close(dy_scale_test, dy_scale_ref)


def run_parallel_tests() -> None:
    """Run parallel tests

    Meant to be called in every process of a ``torch.distributed`` job. Each
    sub-test is run for every applicable configuration. Progress messages are
    printed only on rank 0.
    """

    # Distributed process group
    process_group = world_group()
    rank = torch.distributed.get_rank(process_group)

    # Collective communication ops
    if rank == 0:
        print("Running _test_all_reduce")
    _test_all_reduce()
    for quantization in quantization_list:
        if rank == 0:
            print(f"Running _test_all_gather with quantization={quantization}")
        _test_all_gather(quantization=quantization)
    if rank == 0:
        print("Running _test_reduce_scatter")
    _test_reduce_scatter()

    # Basic linear op
    for config in itertools.product(
        quantization_list,
        ("column", "row"),
        (False, True),
    ):
        if rank == 0:
            print(f"Running _test_basic_linear with {config=}")
        quantization, tensor_parallel_mode, sequence_parallel = config
        _test_basic_linear(
            quantization=quantization,
            tensor_parallel_mode=tensor_parallel_mode,
            sequence_parallel=sequence_parallel,
        )

    # Compute dtype for the linear and MLP tests: BF16 when supported, else FP32.
    # Loop-invariant, so it is computed once here.
    dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32

    # Linear op
    for config in itertools.product(
        quantization_list,
        ("column", "row"),
        (False, True),
    ):
        if rank == 0:
            print(f"Running _test_linear with {config=}")
        quantization, tensor_parallel_mode, sequence_parallel = config
        _test_linear(
            bias=True,  # bias=False is tested in _test_basic_linear
            dtype=dtype,
            quantization=quantization,
            tensor_parallel_mode=tensor_parallel_mode,
            sequence_parallel=sequence_parallel,
        )

    # MLP
    for config in itertools.product(quantization_list, (False, True)):
        if rank == 0:
            print(f"Running _test_mlp with {config=}")
        quantization, sequence_parallel = config
        _test_mlp(
            bias=True,  # bias=False is tested in _test_basic_linear
            dtype=dtype,
            quantization=quantization,
            sequence_parallel=sequence_parallel,
        )

    # FP8 scale update (only meaningful when FP8 is supported on this device)
    if fp8_available:
        if rank == 0:
            print("Running _test_fp8_scale_update")
        _test_fp8_scale_update()


# Parallel job sizes: one job spanning every visible GPU, plus a
# single-process job and, when at least two GPUs exist, a two-process job.
# Each extra size is added only if it is not already in the list.
_world_sizes = [torch.cuda.device_count()]
for _extra_size, _wanted in ((1, True), (2, torch.cuda.device_count() >= 2)):
    if _wanted and _extra_size not in _world_sizes:
        _world_sizes.append(_extra_size)
del _extra_size, _wanted


@pytest.mark.parametrize("world_size", _world_sizes)
def test_distributed_fuser_ops(world_size: int) -> None:
    """Launch parallel job that runs parallel tests

    Re-invokes this file through ``torch.distributed.run`` with
    ``world_size`` processes on the local node. The ``--parallel`` flag is
    passed so that each process runs ``run_parallel_tests`` via ``main``.
    A non-zero exit status from the launcher raises
    ``subprocess.CalledProcessError``, which fails the test.
    """
    python_exe = pathlib.Path(sys.executable).resolve()
    current_file = pathlib.Path(__file__).resolve()
    command = [
        str(python_exe),
        "-m",
        "torch.distributed.run",
        f"--nproc_per_node={world_size}",
        str(current_file),
        "--parallel",
    ]
    subprocess.run(command, check=True)


def main() -> None:
    """Command-line entry point.

    With ``--parallel``, runs the parallel tests in this process. That flag is
    how the job launched by ``test_distributed_fuser_ops`` invokes this file.
    Without the flag, this function does nothing.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--parallel", action="store_true", help="Run parallel tests")
    if cli.parse_args().parallel:
        run_parallel_tests()


if __name__ == "__main__":
    main()