test_fusible_ops.py 65.6 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
#
# See LICENSE for license information.

from __future__ import annotations

7
from collections.abc import Iterable
8
import math
9
from typing import Optional
10
11
12
13
14

import pytest
import torch

import transformer_engine
Tim Moon's avatar
Tim Moon committed
15
import transformer_engine.common.recipe
16
17
18
19
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
20
21
from transformer_engine.pytorch.ops.fused import (
    BackwardLinearAdd,
22
    ForwardLinearBiasActivation,
23
    ForwardLinearBiasAdd,
24
)
25
26
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
27
28
29
30
31
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex

# Check which quantization schemes (FP8, MXFP8) are supported on this system
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Supported data types
_dtypes: list[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    _dtypes.append(torch.bfloat16)

# Supported devices
_devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")]


43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def maybe_skip_quantization(
    quantization: Optional[str],
    *,
    dims: Optional[Iterable[int] | int] = None,
    device: Optional[torch.device | str] = None,
) -> None:
    """Skip the current test if a quantization configuration is unsupported

    Args:
        quantization: Quantization scheme (``"fp8"`` or ``"mxfp8"``), or
            ``None`` for no quantization, in which case nothing is skipped.
        dims: Tensor dimensions used in quantized GEMMs. The product of
            all but the last dim, and the last dim itself, must each be
            divisible by 16 (FP8) or 32 (MXFP8). A single int is treated
            as a 1-D shape. Any iterable is accepted.
        device: Device the test runs on. Quantization requires CUDA.

    """

    # Don't skip if there is no quantization
    if quantization is None:
        return

    # Check if quantization scheme is supported
    if quantization == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Check if dims are compatible with quantized GEMMs
    if dims is not None:
        # Materialize as a tuple so slicing also works for non-sequence
        # iterables (e.g. generators)
        dims = tuple(dims) if isinstance(dims, Iterable) else (dims,)
        if dims:  # an empty shape has nothing to check
            leading_dims = math.prod(dims[:-1])
            last_dim = dims[-1]
            if quantization == "fp8":
                if leading_dims % 16 != 0 or last_dim % 16 != 0:
                    pytest.skip("FP8 GEMMs require dims that are divisible by 16")
            elif quantization == "mxfp8":
                if leading_dims % 32 != 0 or last_dim % 32 != 0:
                    pytest.skip("MXFP8 GEMMs require dims that are divisible by 32")

    # Check if device is supported
    if device is not None and torch.device(device).type != "cuda":
        pytest.skip("Quantization is only supported on CUDA devices")


75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.

    Args:
        dtype: PyTorch dtype or Transformer Engine dtype.

    Returns:
        Fresh dict with ``rtol`` and ``atol`` entries, suitable as keyword
        arguments to ``torch.testing.assert_close``.

    Raises:
        ValueError: if no tolerances are defined for ``dtype``.

    """

    # Transformer Engine dtypes: FP8 formats get their own tolerances,
    # other formats are translated to the equivalent PyTorch dtype.
    # PyTorch dtypes short-circuit the check and skip this branch.
    if not isinstance(dtype, torch.dtype) and isinstance(dtype, tex.DType):
        if dtype == tex.DType.kFloat8E4M3:
            # 3 mantissa bits: machine epsilon = 2^-3 = 0.125
            # NOTE(review): atol 0.0675 may be intended as 0.0625 (2^-4) — confirm
            return dict(rtol=0.125, atol=0.0675)
        if dtype == tex.DType.kFloat8E5M2:
            # 2 mantissa bits: machine epsilon = 2^-2 = 0.25
            return dict(rtol=0.25, atol=0.125)
        tex_to_torch = {
            tex.DType.kByte: torch.uint8,
            tex.DType.kInt32: torch.int32,
            tex.DType.kFloat32: torch.float32,
            tex.DType.kFloat16: torch.half,
            tex.DType.kBFloat16: torch.bfloat16,
        }
        if dtype not in tex_to_torch:
            raise ValueError(f"Unsupported dtype ({dtype})")
        dtype = tex_to_torch[dtype]

    # PyTorch dtypes (a new dict is returned each call, so callers may mutate it)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float64:
        return dict(rtol=1e-7, atol=1e-7)
    raise ValueError(f"Unsupported dtype ({dtype})")


@torch.no_grad()
def make_reference_and_test_tensors(
    shape: int | Iterable[int],
    ref_dtype: torch.dtype = torch.float64,
    ref_device: torch.device = "cpu",
    test_dtype: torch.dtype = torch.float32,
    test_device: torch.device = "cuda",
    test_is_fp8: bool = False,
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Construct tensors with the same values

    The reference tensor is intended for use in plain PyTorch
    operations in high precision. The test tensor is intended for use
    in Transformer Engine operations.

    Values are drawn with ``torch.rand`` (uniform in [0, 1)). The
    reference tensor is then overwritten with the values of the test
    tensor, so any rounding from ``test_dtype`` or FP8 quantization is
    reflected in both.

    Args:
        shape: Tensor shape, as an int or any iterable of ints.
        ref_dtype: Dtype of the reference tensor.
        ref_device: Device of the reference tensor.
        test_dtype: Dtype of the test tensor (before any FP8 quantization).
        test_device: Device of the test tensor.
        test_is_fp8: If True, the test tensor is quantized to FP8 E4M3
            with a scale of 1.
        requires_grad: Value of ``requires_grad`` set on both tensors.

    Returns:
        ``(ref, test)`` tensors holding identical values, never aliasing
        the same storage.

    """
    # torch.rand needs an int or a sequence; materialize other iterables
    if isinstance(shape, Iterable):
        shape = tuple(shape)
    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
    test = ref.to(device=test_device, dtype=test_dtype)
    if test_is_fp8:
        # Unit scale: torch.rand values lie in [0, 1)
        quantizer = Float8Quantizer(
            scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
            amax=torch.zeros(1, dtype=torch.float32, device=test_device),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
        test = quantizer(test)
    elif test.data_ptr() == ref.data_ptr():
        # Tensor.to returns the same tensor when dtype and device already
        # match; clone so ref and test do not share storage
        test = test.clone()

    # Make the reference hold exactly the (rounded/quantized) test values
    ref.copy_(test)
    ref.requires_grad_(requires_grad)
    test.requires_grad_(requires_grad)
    return ref, test


142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def make_recipe(
    name: Optional[str] = None,
) -> Optional[transformer_engine.common.recipe.Recipe]:
    """Make recipe for quantization scheme

    Args:
        name: ``"fp8"`` for delayed-scaling FP8, ``"mxfp8"`` for MXFP8
            block scaling, or ``None`` for no quantization.

    Returns:
        Recipe configured with the E4M3 format, or ``None`` if ``name``
        is ``None``.

    Raises:
        ValueError: if ``name`` is not a supported scheme.

    """
    if name is None:
        return None
    if name == "fp8":
        return transformer_engine.common.recipe.DelayedScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
    if name == "mxfp8":
        return transformer_engine.common.recipe.MXFP8BlockScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
    raise ValueError(f"Unsupported quantization scheme ({name})")


157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
class TestSequential:
    """Tests for sequential container"""

    def test_modules(self) -> None:
        """Check that list of modules can be manipulated as expected"""

        # Construct sequential container
        modules = [
            te_ops.Identity(),
            te_ops.Identity(),
            torch.nn.Identity(),
            te_ops.Identity(),
        ]
        model = te_ops.Sequential(*modules)

        # Length
        assert len(model) == len(modules)

        # Iterator
        for module1, module2 in zip(model, modules):
            assert module1 is module2

        # Index by int
        for i, module in enumerate(modules):
            assert model[i] is module
            assert model[i - len(modules)] is module

        # Index by slice
        model_subset = model[1:-1]
        modules_subset = modules[1:-1]
        assert isinstance(model_subset, te_ops.Sequential)
        for module1, module2 in zip(model_subset, modules_subset):
            assert module1 is module2

        # Set element
        new_module = torch.nn.Identity()
        idx = 1
        modules[idx] = new_module
        model[idx] = new_module
        for module1, module2 in zip(model, modules):
            assert module1 is module2

        # Delete element
        idx = 1
        del modules[idx]
        del model[idx]
        for module1, module2 in zip(model, modules):
            assert module1 is module2

        # Append
        new_module = torch.nn.Identity()
        modules.append(new_module)
        model.append(new_module)
        for module1, module2 in zip(model, modules):
            assert module1 is module2

        # Extend
        new_modules = [te_ops.Identity(), te_ops.Identity()]
        modules.extend(new_modules)
        model.extend(new_modules)
        for module1, module2 in zip(model, modules):
            assert module1 is module2

        # Insert
        new_module = te_ops.Identity()
        idx = 2
        modules.insert(idx, new_module)
        model.insert(idx, new_module)
        for module1, module2 in zip(model, modules):
            assert module1 is module2

        # Pop
        idx = 2
        assert model.pop(idx) is modules.pop(idx)
        for module1, module2 in zip(model, modules):
            assert module1 is module2

        # Out-of-place add
        new_modules = [torch.nn.Identity(), te_ops.Identity()]
        added_modules = modules + new_modules
        added_model = model + te_ops.Sequential(*new_modules)
        for module1, module2 in zip(model, modules):
            assert module1 is module2
        for module1, module2 in zip(added_model, added_modules):
            assert module1 is module2

        # In-place add
        new_modules = [te_ops.Identity(), torch.nn.Identity()]
        modules += new_modules
        model += te_ops.Sequential(*new_modules)
        for module1, module2 in zip(model, modules):
            assert module1 is module2

    def test_module_groups(self) -> None:
        """Check that modules are grouped together correctly"""
        model = te_ops.Sequential(
            te_ops.Identity(),
            te_ops.Identity(),
            torch.nn.Identity(),
            torch.nn.Identity(),
            te_ops.Identity(),
            torch.nn.Identity(),
            te_ops.Identity(),
            te_ops.Identity(),
            te_ops.Identity(),
        )
        model(torch.zeros(1))
        assert len(model._module_groups) == 6


class TestFuser:
    """Tests for operation fusion infrastructure"""

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA RNGs before any test in this class runs"""
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_scale_update(
        self,
        size: int = 16,
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ) -> None:
        """Test FP8 scaling factors with delayed scaling recipe"""

        # FP8 recipe
        margin = 2
        fp8_format = transformer_engine.common.recipe.Format.HYBRID
        recipe = transformer_engine.common.recipe.DelayedScaling(
            margin=margin,
            fp8_format=fp8_format,
            amax_history_len=8,
            amax_compute_algo="max",
        )

        # Construct model
        with te.fp8_model_init(recipe=recipe):
            model = te_ops.basic.BasicLinear(
                size,
                size,
                device=device,
                dtype=dtype,
            )

        # Training steps
        # Constant-valued tensors make the expected outputs and amaxes easy
        # to compute by hand. The weight is refilled after each step.
        w_vals = [2, 5, 3, 11]
        x_vals = [7, 3, 5]
        dy_vals = [1, 2, 1]
        with torch.no_grad():
            model.weight.fill_(w_vals[0])
        for step in range(3):

            # Data tensors
            x = torch.full(
                (size, size),
                x_vals[step],
                dtype=dtype,
                device=device,
                requires_grad=True,
            )
            dy = torch.full(
                (size, size),
                dy_vals[step],
                dtype=dtype,
                device=device,
            )

            # Training step
            with te.fp8_autocast(fp8_recipe=recipe):
                y = model(x)
            y.backward(dy)
            with torch.no_grad():
                model.weight.fill_(w_vals[step + 1])

            # Check that output tensors match expected
            # (each output entry is a dot product over `size` constant terms)
            y_val_ref = w_vals[step] * x_vals[step] * size
            dx_val_ref = w_vals[step] * dy_vals[step] * size
            torch.testing.assert_close(
                y,
                torch.full_like(y, y_val_ref),
                **dtype_tols(tex.DType.kFloat8E4M3),
            )
            torch.testing.assert_close(
                x.grad,
                torch.full_like(x.grad, dx_val_ref),
                **dtype_tols(tex.DType.kFloat8E5M2),
            )

            # Check that scaling factors match expected
            # With amax_compute_algo="max", amax is the max over the history so far;
            # scale = fp8_max / amax / 2^margin
            w_amax_ref = max(w_vals[: step + 1])
            x_amax_ref = max(x_vals[: step + 1])
            dy_amax_ref = max(dy_vals[: step + 1])
            w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin)
            x_scale_ref = (fp8_format.value.max_fwd / x_amax_ref) / (2**margin)
            dy_scale_ref = (fp8_format.value.max_bwd / dy_amax_ref) / (2**margin)
            w_scale = model.get_quantizer("forward", 1).scale
            x_scale = model.get_quantizer("forward", 0).scale
            dy_scale = model.get_quantizer("backward", 0).scale
            torch.testing.assert_close(w_scale, torch.full_like(w_scale, w_scale_ref))
            torch.testing.assert_close(x_scale, torch.full_like(x_scale, x_scale_ref))
            torch.testing.assert_close(dy_scale, torch.full_like(dy_scale, dy_scale_ref))

    @pytest.mark.parametrize("init_dtype", _dtypes)
    @pytest.mark.parametrize("final_dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    def test_dtype_cast(
        self,
        *,
        size: int = 32,
        init_dtype: torch.dtype,
        final_dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
    ) -> None:
        """Check dtype cast functions"""

        # Skip invalid configurations
        maybe_skip_quantization(quantization, device=device)
        with_quantization = quantization is not None

        # Random data
        # Generate values in the lower-precision dtype involved, presumably
        # so the exact weight comparison below survives both casts
        dtype = torch.float32
        if torch.float16 in (init_dtype, final_dtype):
            dtype = torch.float16
        if torch.bfloat16 in (init_dtype, final_dtype):
            dtype = torch.bfloat16
        w_ref, w_test = make_reference_and_test_tensors(
            (size, size),
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=with_quantization,
        )

        # Construct operation
        with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype)
        with torch.no_grad():
            op.weight.copy_(w_test)
            del w_test

        # Cast operation dtype
        if final_dtype == torch.float32:
            op.float()
        elif final_dtype == torch.float16:
            op.half()
        elif final_dtype == torch.bfloat16:
            op.bfloat16()

        # Check weights
        assert isinstance(op.weight, QuantizedTensor) == with_quantization
        assert op.weight.dtype == final_dtype
        w_test = op.weight.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0)

        # Check forward and backward pass
        x = torch.zeros(
            (size, size),
            dtype=init_dtype,
            device=device,
            requires_grad=True,
        )
        y = op(x)
        y.backward(torch.zeros_like(y))
        assert y.dtype == final_dtype
        assert x.grad.dtype == init_dtype
        assert op.weight.grad.dtype == final_dtype

    @pytest.mark.parametrize("model_dtype", _dtypes)
    @pytest.mark.parametrize("autocast_dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    def test_pyt_autocast(
        self,
        *,
        size: int = 32,
        model_dtype: torch.dtype,
        autocast_dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weights: bool = False,
    ) -> None:
        """Test with PyTorch autocast"""
        device = torch.device(device)

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, device=device)

        # Construct operation
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weights, recipe=recipe):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype)

        # Check forward and backward pass
        x = torch.zeros(
            (size, size),
            dtype=model_dtype,
            device=device,
            requires_grad=True,
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
                y = op(x)
        y.backward(torch.zeros_like(y))
        assert y.dtype == autocast_dtype
        assert x.grad.dtype == model_dtype
        assert op.weight.grad.dtype == model_dtype

        # Check forward and backward pass (swapped context order)
        if quantized_compute:
            x.grad = None
            op.weight.grad = None
            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
                with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
                    y = op(x)
            y.backward(torch.zeros_like(y))
            assert y.dtype == autocast_dtype
            assert x.grad.dtype == model_dtype
            assert op.weight.grad.dtype == model_dtype

481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497

class TestBasicOps:
    """Tests for individual operations"""

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA RNGs before any test in this class runs"""
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("fp8", (False, True))
    def test_identity(
        self,
        *,
        in_shape: Iterable[int] = (1,),
        dtype: torch.dtype,
        device: torch.device,
        fp8: bool,
    ) -> None:
        """Identity op passes values through unchanged in forward and backward"""

        # Skip invalid configurations (FP8 availability, CUDA-only FP8)
        maybe_skip_quantization("fp8" if fp8 else None, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = x_ref
        dx_ref = dy_ref

        # Implementation with fusible operation
        op = te_ops.Identity()
        y_test = op(x_test)
        y_test.backward(dy_test)

        # Check results
        tols = dict(rtol=0, atol=0)  # Identity is exact
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, dx_ref, **tols)

        # Make sure we are not trivially passing the test
        with pytest.raises(AssertionError):
            torch.testing.assert_close(y_test, -y_ref, **tols)
        with pytest.raises(AssertionError):
            torch.testing.assert_close(dx_test, -dx_ref, **tols)

    @pytest.mark.parametrize(
        "shapes",
        (
            ((1, 2, 3, 4), (2, 12)),
            ((5, 4, 3, 2), (-1, 6)),
            ((30,), (2, 3, -1)),
            ((6, 7), (3, -1, 7)),
        ),
    )
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("fp8", (False, True))
    def test_reshape(
        self,
        *,
        shapes: tuple[Iterable[int], Iterable[int]],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        memory_format: torch.memory_format = torch.contiguous_format,
        fp8: bool,
    ) -> None:
        """Reshape op matches ``torch.Tensor.reshape`` exactly in forward and backward"""
        in_shape, out_shape = shapes

        # Skip invalid configurations
        if memory_format == torch.channels_last and len(in_shape) != 4:
            pytest.skip("torch.channels_last only supports 4D tensors")
        maybe_skip_quantization("fp8" if fp8 else None, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8,
        )
        x_test = x_test.contiguous(memory_format=memory_format)
        # Re-create a leaf tensor so x_test.grad is populated after backward
        x_test = x_test.detach().requires_grad_()
        dy_ref, dy_test = make_reference_and_test_tensors(
            x_ref.reshape(out_shape).size(),
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = x_ref.reshape(out_shape)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.Reshape(out_shape)
        y_test = op(x_test)
        y_test.backward(dy_test)

        # Check results
        tols = dict(rtol=0, atol=0)  # Reshape is exact
        y_test = y_test.to(
            dtype=torch.float64,
            device="cpu",
            memory_format=torch.contiguous_format,
        )
        dx_test = x_test.grad.to(
            dtype=torch.float64,
            device="cpu",
            memory_format=torch.contiguous_format,
        )
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("size", (1, 7, 32))
    @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", _devices)
    @pytest.mark.parametrize("fp8", (False, True))
    def test_bias(
        self,
        *,
        size: int,
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device,
        fp8: bool,
    ) -> None:
        """Bias op matches a broadcasted add, including input and bias grads"""

        # Make input and bias shapes consistent (last dim is the bias size)
        in_shape = list(in_shape)[:-1] + [size]

        # Skip invalid configurations (FP8 availability, CUDA-only FP8)
        maybe_skip_quantization("fp8" if fp8 else None, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8,
        )
        b_ref, b_test = make_reference_and_test_tensors(
            size,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation (bias broadcast over leading dims)
        y_ref = x_ref + b_ref.reshape([1] * (len(in_shape) - 1) + [size])
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.Bias(size, device=device, dtype=dtype)
        with torch.no_grad():
            op.bias.copy_(b_test)
            del b_test
        y_test = op(x_test)
        y_test.backward(dy_test)

        # Check results
        tols = dtype_tols(dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(db_test, b_ref.grad, **tols)

    @pytest.mark.parametrize("quantization", ("fp8", "mxfp8"))
    @pytest.mark.parametrize("cast_forward", (False, True))
    @pytest.mark.parametrize("cast_backward", (False, True))
    def test_quantize(
        self,
        *,
        in_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device = "cuda",
        quantization: str,
        cast_forward: bool,
        cast_backward: bool,
    ) -> None:
        """Quantize op emits quantized tensors only where requested, without changing values"""

        # Skip invalid configurations
        maybe_skip_quantization(quantization, device=device)

        # Random data
        # Values are generated as FP8 and then dequantized, so a later
        # quantization in the op should be able to reproduce them exactly
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
            test_is_fp8=True,
        )
        x_test = x_test.dequantize().requires_grad_()
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
            test_is_fp8=True,
        )
        dy_test = dy_test.dequantize()

        # Plain PyTorch implementation
        y_ref = x_ref
        dx_ref = dy_ref

        # Implementation with fusible operation
        op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
        recipe = make_recipe(quantization)
        with te.fp8_autocast(fp8_recipe=recipe):
            y_test = op(x_test)
        y_test.backward(dy_test)

        # Check tensor types
        assert isinstance(y_test, QuantizedTensor) == cast_forward
        assert isinstance(x_test.grad, QuantizedTensor) == cast_backward

        # Check values
        tols = dict(rtol=0, atol=0)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, dx_ref, **tols)

    def _test_basic_linear(
        self,
        *,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        quantization: Optional[str] = None,
        quantized_compute: bool = False,
        quantized_input: bool = False,
        quantized_weight: bool = False,
        quantized_output: bool = False,
        quantized_grad_output: bool = False,
        quantized_grad_input: bool = False,
        accumulate_into_main_grad: bool = False,
    ) -> None:
        """Helper function for tests with GEMM

        Runs ``te_ops.BasicLinear`` wrapped between two ``te_ops.Quantize``
        ops and compares output, input grad, and weight grad against
        ``torch.nn.functional.linear`` computed in float64 on CPU.

        Args:
            weight_shape: ``(out_features, in_features)``.
            in_shape: Input shape; its last dim is replaced by ``in_features``.
            dtype: Compute dtype of the test tensors and op.
            device: Device of the test tensors and op.
            quantization: Quantization scheme (``"fp8"``, ``"mxfp8"``) or ``None``.
            quantized_compute: Run the op inside ``te.fp8_autocast``.
            quantized_input: Quantize the input before the op.
            quantized_weight: Construct the weight under ``te.fp8_model_init``.
            quantized_output: Quantize the op output.
            quantized_grad_output: Quantize the grad w.r.t. the op output.
            quantized_grad_input: Quantize the grad w.r.t. the op input.
            accumulate_into_main_grad: Accumulate weight grad into
                ``weight.main_grad`` instead of ``weight.grad``.

        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantization == "fp8" and quantized_output and not quantized_compute:
            pytest.skip("FP8 output is only supported with FP8 GEMMs")
        if quantization == "fp8" and quantized_grad_input and not quantized_compute:
            pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
        if quantization == "mxfp8" and quantized_output:
            pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
        if quantization == "mxfp8" and quantized_grad_input:
            pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")

        # Random data
        # Tensors that will be quantized are generated as FP8 values
        # (presumably so quantization inside the op adds no extra error),
        # then input and grad output are passed to the op dequantized
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_input),
        )
        if isinstance(x_test, QuantizedTensor):
            with torch.no_grad():
                x_test = x_test.dequantize().requires_grad_()
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_weight),
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_grad_output),
            requires_grad=False,
        )
        if isinstance(dy_test, QuantizedTensor):
            dy_test = dy_test.dequantize()

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            op = te_ops.BasicLinear(
                in_features,
                out_features,
                device=device,
                dtype=dtype,
                accumulate_into_main_grad=accumulate_into_main_grad,
            )
        with torch.no_grad():
            op.weight.copy_(w_test)
            del w_test
            # Non-zero starting value so accumulation into main_grad is
            # detectable; 0.5 is subtracted again when checking dw below
            op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
        forward = te_ops.Sequential(
            te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
            op,
            te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute or quantized_output or quantized_grad_input:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        if accumulate_into_main_grad:
            # Weight grad goes to main_grad; weight.grad, if set, must be zero
            if op.weight.grad is not None:
                torch.testing.assert_close(
                    op.weight.grad,
                    torch.zeros_like(op.weight.grad),
                    rtol=0,
                    atol=0,
                )
            dw_test = op.weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5
        else:
            # Weight grad goes to weight.grad; main_grad must be untouched
            dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(
                op.weight.main_grad,
                torch.full_like(op.weight.main_grad, 0.5),
                rtol=0,
                atol=0,
            )
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)

    @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5)))
    @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
    def test_basic_linear(
        self,
        *,
        weight_shape: tuple[int, int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        quantization: Optional[str],
        accumulate_into_main_grad: bool,
    ) -> None:
        """GEMM

        Thin parametrized wrapper around ``self._test_basic_linear``. Quantized
        compute is enabled exactly when a quantization scheme is given; the
        per-tensor quantization flags keep the helper's defaults.

        """
        self._test_basic_linear(
            weight_shape=weight_shape,
            in_shape=in_shape,
            dtype=dtype,
            quantization=quantization,
            quantized_compute=quantization is not None,
            accumulate_into_main_grad=accumulate_into_main_grad,
        )

    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    @pytest.mark.parametrize("quantization", ("fp8", "mxfp8"))
    @pytest.mark.parametrize("quantized_compute", (False, True))
    @pytest.mark.parametrize("quantized_input", (False, True))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    @pytest.mark.parametrize("quantized_output", (False, True))
    @pytest.mark.parametrize("quantized_grad_output", (False, True))
    @pytest.mark.parametrize("quantized_grad_input", (False, True))
    def test_basic_linear_quantized(
        self,
        *,
        quantization: str,
        quantized_compute: bool,
        quantized_input: bool,
        quantized_weight: bool,
        quantized_output: bool,
        quantized_grad_output: bool,
        quantized_grad_input: bool,
    ) -> None:
        """GEMM with FP8 inputs and outputs

        Sweeps every combination of the ``quantized_*`` flags, which are
        forwarded unchanged to ``self._test_basic_linear``. The helper uses
        them to choose which tensors (input, weight, output and the two
        gradients) are quantized and whether the GEMM itself runs under
        ``te.fp8_autocast``. Always runs in BF16.

        """
        self._test_basic_linear(
            dtype=torch.bfloat16,
            quantization=quantization,
            quantized_compute=quantized_compute,
            quantized_input=quantized_input,
            quantized_weight=quantized_weight,
            quantized_output=quantized_output,
            quantized_grad_output=quantized_grad_output,
            quantized_grad_input=quantized_grad_input,
        )

    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    def test_linear(
        self,
        *,
        bias: bool,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool,
    ) -> None:
        """GEMM + bias

        Compares ``te_ops.Linear`` against ``torch.nn.functional.linear`` for
        the forward output and the input/weight/bias gradients.

        """

        # Make input and weight shapes consistent: the last input dim is
        # replaced by in_features, the last output dim is out_features.
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=quantized_compute,
        )
        if isinstance(x_test, QuantizedTensor):
            # Hand the op a plain (dequantized) leaf tensor that still tracks grads
            with torch.no_grad():
                x_test = x_test.dequantize().requires_grad_()
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_weight),
        )
        b_ref, b_test = None, None
        if bias:
            b_ref, b_test = make_reference_and_test_tensors(
                out_features,
                test_dtype=dtype,
                test_device=device,
            )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            op = te_ops.Linear(
                in_features,
                out_features,
                bias=bias,
                device=device,
                dtype=dtype,
            )
        with torch.no_grad():
            op.weight.copy_(w_test)
            if bias:
                op.bias.copy_(b_test)
            del w_test
            del b_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = op(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
        if bias:
            db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(db_test, b_ref.grad, **tols)

    @pytest.mark.parametrize("weight_shape", ((7, 2), (32,)))
    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("zero_centered_gamma", (False, True))
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    def test_layer_norm(
        self,
        *,
        weight_shape: Iterable[int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        eps: float = 0.3,
        zero_centered_gamma: bool,
        quantization: Optional[str],
    ) -> None:
        """Layer norm

        Compares ``te_ops.LayerNorm``, optionally followed by a forward-only
        ``te_ops.Quantize``, against ``torch.nn.functional.layer_norm``.

        """

        # Make input and weight shapes consistent: the trailing input dims
        # are replaced by the normalized (weight) shape.
        in_shape = list(in_shape)[:-1] + list(weight_shape)

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            weight_shape,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref, b_test = make_reference_and_test_tensors(
            weight_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation (zero-centered gamma stores weight - 1)
        y_ref = torch.nn.functional.layer_norm(
            x_ref,
            weight_shape,
            weight=(w_ref + 1 if zero_centered_gamma else w_ref),
            bias=b_ref,
            eps=eps,
        )
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.LayerNorm(
            weight_shape,
            eps=eps,
            device=device,
            dtype=dtype,
            zero_centered_gamma=zero_centered_gamma,
        )
        with torch.no_grad():
            op.weight.copy_(w_test)
            op.bias.copy_(b_test)
            del w_test
            del b_test
        quantized_compute = quantization is not None
        recipe = make_recipe(quantization)
        forward = te_ops.Sequential(
            op,
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
        db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
        torch.testing.assert_close(db_test, b_ref.grad, **tols)

    def test_layer_norm_autocast(
        self,
        *,
        weight_shape: Iterable[int] = (32,),
        in_shape: Iterable[int] = (32,),
        dtype: torch.dtype = torch.float16,
        autocast_dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        eps: float = 0.3,
    ) -> None:
        """Layer norm with PyTorch autocast

        Parameters are held in ``dtype`` while the input, output and grad
        output are in ``autocast_dtype``. The output is expected to come back
        in ``autocast_dtype``, and each gradient is checked with the tolerance
        of the dtype its tensor lives in.

        """

        # Make input and weight shapes consistent
        in_shape = list(in_shape)[:-1] + list(weight_shape)

        # Random data (activations in autocast dtype, params in param dtype)
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=autocast_dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            weight_shape,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref, b_test = make_reference_and_test_tensors(
            weight_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=autocast_dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.layer_norm(
            x_ref,
            weight_shape,
            weight=w_ref,
            bias=b_ref,
            eps=eps,
        )
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.LayerNorm(
            weight_shape,
            eps=eps,
            device=device,
            dtype=dtype,
        )
        with torch.no_grad():
            op.weight.copy_(w_test)
            op.bias.copy_(b_test)
            del w_test
            del b_test
        with torch.autocast(device, dtype=autocast_dtype):
            y_test = op(x_test)
        y_test.backward(dy_test)

        # Check results
        assert y_test.dtype == autocast_dtype
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
        db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **dtype_tols(autocast_dtype))
        torch.testing.assert_close(dx_test, x_ref.grad, **dtype_tols(autocast_dtype))
        torch.testing.assert_close(dw_test, w_ref.grad, **dtype_tols(dtype))
        torch.testing.assert_close(db_test, b_ref.grad, **dtype_tols(dtype))

    @pytest.mark.parametrize("weight_shape", ((19,), (64,)))
    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("zero_centered_gamma", (False, True))
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    def test_rmsnorm(
        self,
        *,
        weight_shape: Iterable[int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        eps: float = 0.3,
        zero_centered_gamma: bool,
        quantization: Optional[str],
    ) -> None:
        """RMSNorm

        Compares ``te_ops.RMSNorm``, optionally followed by a forward-only
        ``te_ops.Quantize``, against a hand-written PyTorch reference.

        """

        # Make input and weight shapes consistent
        in_shape = list(in_shape)[:-1] + list(weight_shape)

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            weight_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation: normalize by the root mean square over
        # the trailing (normalized) dims
        inner_dims = tuple(range(len(in_shape) - len(weight_shape), len(in_shape)))
        var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape)
        if zero_centered_gamma:
            y_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref)
        else:
            y_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.RMSNorm(
            weight_shape,
            eps=eps,
            device=device,
            dtype=dtype,
            zero_centered_gamma=zero_centered_gamma,
        )
        with torch.no_grad():
            op.weight.copy_(w_test)
            del w_test
        quantized_compute = quantization is not None
        recipe = make_recipe(quantization)
        forward = te_ops.Sequential(
            op,
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("fp8", (False, True))
    def test_add_in_place(
        self,
        *,
        in_shape: Iterable[int] = (1,),
        dtype: torch.dtype,
        device: torch.device,
        fp8: bool,
    ) -> None:
        """Add two tensors

        Join in compute graph.

        """

        # Skip invalid configurations
        if fp8 and not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8 and torch.device(device).type != "cuda":
            pytest.skip("FP8 is only supported on CUDA devices")

        # Random data
        x1_ref, x1_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8,
        )
        x2_ref, x2_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation: y = x2 + x1, and both inputs receive
        # the output gradient unchanged
        y_ref = x2_ref.detach()
        y_ref += x1_ref
        dx1_ref = dy_ref
        dx2_ref = dy_ref

        # Implementation with fusible operation
        op = te_ops.AddInPlace()
        y_test = op(x1_test, x2_test)
        y_test.backward(dy_test)

        # Check results (gradients are pass-through, so they must match exactly)
        tols = dtype_tols(dtype)
        if fp8:
            tols = dtype_tols(x1_test._fp8_dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
        dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0)
        torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("fp8", (False, True))
    def test_make_extra_output(
        self,
        *,
        in_shape: Iterable[int] = (1,),
        dtype: torch.dtype,
        device: torch.device,
        fp8: bool,
    ) -> None:
        """Output tensor twice

        Split in compute graph.

        """

        # Skip invalid configurations
        if fp8 and not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8 and torch.device(device).type != "cuda":
            pytest.skip("FP8 is only supported on CUDA devices")

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8,
        )
        dy1_ref, dy1_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        dy2_ref, dy2_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation: both outputs alias the input, so the
        # input gradient is the sum of the two output gradients
        y1_ref = x_ref
        y2_ref = x_ref
        (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()

        # Implementation with fusible operation
        op = te_ops.MakeExtraOutput()
        y1_test, y2_test = op(x_test)
        (y1_test * dy1_test + y2_test * dy2_test).sum().backward()

        # Check results (forward outputs are copies of the input, so exact)
        tols = dtype_tols(dtype)
        y1_test = y1_test.to(dtype=torch.float64, device="cpu")
        y2_test = y2_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y1_test, y1_ref, rtol=0, atol=0)
        torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
    @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    def test_activation(
        self,
        *,
        activation: str,
        out_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
    ) -> None:
        """Activation functions

        Compares each ``te_ops`` activation, optionally followed by a
        forward-only ``te_ops.Quantize``, against its PyTorch equivalent.

        """

        # Tensor dimensions: gated activations consume twice the output width
        in_shape = list(out_shape)
        if activation in ("geglu", "reglu", "swiglu"):
            in_shape[-1] *= 2

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=quantized_compute,
        )
        if quantized_compute:
            # Hand the op a plain (dequantized) leaf tensor that still tracks grads
            with torch.no_grad():
                x_test = x_test.dequantize().requires_grad_()
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref: torch.Tensor
        if activation == "gelu":
            y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
        elif activation == "relu":
            y_ref = torch.nn.functional.relu(x_ref)
        elif activation == "geglu":
            x1, x2 = x_ref.chunk(2, dim=-1)
            y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2
        elif activation == "reglu":
            x1, x2 = x_ref.chunk(2, dim=-1)
            y_ref = torch.nn.functional.relu(x1) * x2
        elif activation == "swiglu":
            x1, x2 = x_ref.chunk(2, dim=-1)
            y_ref = torch.nn.functional.silu(x1) * x2
        else:
            raise ValueError(f"Unexpected activation function ({activation})")
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        recipe = make_recipe(quantization)
        make_op = dict(
            gelu=te_ops.GELU,
            relu=te_ops.ReLU,
            geglu=te_ops.GEGLU,
            reglu=te_ops.ReGLU,
            swiglu=te_ops.SwiGLU,
        )[activation]
        forward = te_ops.Sequential(
            make_op(),
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)
        if activation == "relu":
            # ReLU only selects or zeroes values, so results are expected to
            # match exactly. Note this overrides the FP8 tolerance above.
            tols = {"atol": 0, "rtol": 0}

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantize_forward", (False, True))
    @pytest.mark.parametrize("quantize_backward", (False, True))
    def test_swiglu(
        self,
        *,
        out_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantize_forward: bool,
        quantize_backward: bool,
    ) -> None:
        """SwiGLU with optional quantization around the op

        ``quantize_forward`` quantizes the SwiGLU output in the forward pass.
        ``quantize_backward`` quantizes the gradient flowing back to the input.
        Both flags require a quantization scheme.

        """

        # Tensor dimensions: SwiGLU halves the last dimension
        in_shape = list(out_shape)
        in_shape[-1] *= 2

        # Skip invalid configurations
        quantized_compute = quantization is not None
        if not quantized_compute and (quantize_forward or quantize_backward):
            pytest.skip("Quantization scheme has not been provided")
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        x1, x2 = x_ref.chunk(2, dim=-1)
        y_ref = torch.nn.functional.silu(x1) * x2
        y_ref.backward(dy_ref)

        # Implementation with fusible operation: the leading Quantize only acts
        # in backward (on dx), the trailing one only in forward (on y)
        recipe = make_recipe(quantization)
        forward = te_ops.Sequential(
            te_ops.Quantize(forward=False, backward=quantize_backward),
            te_ops.SwiGLU(),
            te_ops.Quantize(forward=quantize_forward, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)


class TestFusedOps:
    """Tests for fused operations"""

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA random number generators for this class.

        pytest calls ``setup_class`` once with the test class before running
        its tests. The hook receives the class, so it is declared as a
        classmethod rather than a staticmethod that happens to take ``cls``.
        """
        # Fixed seed so the randomly generated test tensors are reproducible
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
    @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    def test_forward_linear_bias_activation(
        self,
        *,
        bias: bool = True,
        weight_shape: tuple[int, int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool,
    ) -> None:
        """Forward GEMM + bias + activation

        Runs a ``te_ops.Sequential`` holding a single ``te_ops.Linear``,
        checks that its forward ops were fused into
        ``ForwardLinearBiasActivation``, and compares the output and the
        input/weight/bias gradients with a plain PyTorch reference.

        ``quantization`` names the recipe used for quantized compute
        (``None`` disables it). ``quantized_weight`` controls whether the
        weight parameter itself is created under ``te.fp8_model_init``; it
        requires a quantization recipe.
        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        if quantized_weight and not quantized_compute:
            pytest.skip("Quantization scheme has not been provided")
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if dtype not in (torch.float16, torch.bfloat16):
            # Applies with or without quantized compute
            pytest.skip(
                "FP8 fused linear-bias-activation is only supported with FP16 or BF16 output"
            )

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=quantized_compute,
        )
        if isinstance(x_test, QuantizedTensor):
            # Use the dequantized values as a plain tensor that requires grad
            with torch.no_grad():
                x_test = x_test.dequantize().requires_grad_()
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_weight),
        )
        b_ref, b_test = None, None
        if bias:
            b_ref, b_test = make_reference_and_test_tensors(
                out_features,
                test_dtype=dtype,
                test_device=device,
            )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operations.
        # Weight storage follows quantized_weight; quantized compute is
        # enabled separately by fp8_autocast below. Tying fp8_model_init to
        # quantized_compute instead made the quantized_weight parametrization
        # a no-op for quantized recipes.
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                ),
            )
        with torch.no_grad():
            model[0].weight.copy_(w_test)
            if bias:
                model[0].bias.copy_(b_test)
            del w_test
            del b_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)

        # Check that forward operations have been fused
        forward_ops = model._module_groups[0]._forward_ops
        assert len(forward_ops) == 1
        assert isinstance(forward_ops[0][0], ForwardLinearBiasActivation)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
        if bias:
            db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(db_test, b_ref.grad, **tols)

    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    def test_forward_linear_bias_add(
        self,
        *,
        bias: bool,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool = False,
    ) -> None:
        """Forward GEMM + bias + add

        Runs ``te_ops.Linear`` followed by ``te_ops.AddInPlace`` in a
        ``te_ops.Sequential``, checks that the forward ops were fused into
        ``ForwardLinearBiasAdd``, and compares the output and all gradients
        with the plain PyTorch expression ``linear(x1, w, b) + x2``.

        ``quantization`` names the recipe used for quantized compute
        (``None`` disables it). ``quantized_weight`` controls whether the
        weight parameter is created under ``te.fp8_model_init``; it requires a
        quantization recipe.
        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        if quantized_weight and not quantized_compute:
            pytest.skip("Quantization scheme has not been provided")
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")

        # Random data
        x1_ref, x1_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=quantized_compute,
        )
        if isinstance(x1_test, QuantizedTensor):
            # Use the dequantized values as a plain tensor that requires grad
            with torch.no_grad():
                x1_test = x1_test.dequantize().requires_grad_()
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_weight),
        )
        b_ref, b_test = None, None
        if bias:
            b_ref, b_test = make_reference_and_test_tensors(
                out_features,
                test_dtype=dtype,
                test_device=device,
            )
        x2_ref, x2_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x1_ref, w_ref, bias=b_ref) + x2_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                ),
                te_ops.AddInPlace(),
            )
        with torch.no_grad():
            model[0].weight.copy_(w_test)
            if bias:
                model[0].bias.copy_(b_test)
            del w_test
            del b_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            # x2_test is passed as an extra input to the AddInPlace op
            y_test = model(x1_test, x2_test)
        y_test.backward(dy_test)

        # Check that forward operations have been fused
        forward_ops = model._module_groups[0]._forward_ops
        assert len(forward_ops) == 1
        assert isinstance(forward_ops[0][0], ForwardLinearBiasAdd)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
        dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx1_test, x1_ref.grad, **tols)
        torch.testing.assert_close(dx2_test, x2_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
        if bias:
            db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(db_test, b_ref.grad, **tols)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    def test_backward_linear_add(
        self,
        *,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool = False,
    ) -> None:
        """Backward dgrad GEMM + add

        Runs ``te_ops.MakeExtraOutput`` followed by a bias-free
        ``te_ops.Linear``. The model returns both the linear output and its
        input. The test checks that the backward ops were fused into
        ``BackwardLinearAdd`` and compares outputs and gradients with a plain
        PyTorch reference.

        ``quantization`` names the recipe used for quantized compute
        (``None`` disables it). ``quantized_weight`` controls whether the
        weight parameter is created under ``te.fp8_model_init``; it requires a
        quantization recipe.
        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        if quantized_weight and not quantized_compute:
            pytest.skip("Quantization scheme has not been provided")
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=quantized_compute,
        )
        if isinstance(x_test, QuantizedTensor):
            # Use the dequantized values as a plain tensor that requires grad
            with torch.no_grad():
                x_test = x_test.dequantize().requires_grad_()
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=(quantized_compute or quantized_weight),
        )
        dy1_ref, dy1_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        dy2_ref, dy2_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation: y1 = linear(x), y2 = x
        y1_ref = torch.nn.functional.linear(x_ref, w_ref)
        y2_ref = x_ref
        (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()

        # Implementation with fusible operations.
        # Pass the recipe so that quantized weights use the same scheme as
        # the quantized compute (without it fp8_model_init falls back to its
        # default recipe).
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.MakeExtraOutput(),
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=False,
                    device=device,
                    dtype=dtype,
                ),
            )
        with torch.no_grad():
            model[1].weight.copy_(w_test)
            del w_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y1_test, y2_test = model(x_test)
        (y1_test * dy1_test + y2_test * dy2_test).sum().backward()

        # Check that backward operations have been fused
        backward_ops = model._module_groups[0]._backward_ops
        assert len(backward_ops) == 1
        assert isinstance(backward_ops[0][0], BackwardLinearAdd)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y1_test = y1_test.to(dtype=torch.float64, device="cpu")
        y2_test = y2_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y1_test, y1_ref, **tols)
        torch.testing.assert_close(y2_test, y2_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)