# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from __future__ import annotations

7
from collections.abc import Iterable
8
import io
9
import math
10
11
import pathlib
import sys
12
from typing import Optional
13
14
15

import pytest
import torch
yuguo's avatar
yuguo committed
16
from torch.utils.cpp_extension import IS_HIP_EXTENSION
17
18

import transformer_engine
Tim Moon's avatar
Tim Moon committed
19
import transformer_engine.common.recipe
20
21
22
23
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
24
25
from transformer_engine.pytorch.ops.fused import (
    BackwardLinearAdd,
26
    ForwardLinearBiasActivation,
27
    ForwardLinearBiasAdd,
28
)
29
from transformer_engine.pytorch.tensor import QuantizedTensor
30
31
32
33
34
35
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Tensor,
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
36
37
38
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex

39
40
41
42
43
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent))
from utils import dtype_tols, make_recipe

yuguo's avatar
yuguo committed
44
45
46
47
48
49
50
51
if IS_HIP_EXTENSION:
    import os
    from functools import cache

    @cache
    def use_hipblaslt() -> bool:
        """Return True when hipBLASLt (rather than rocBLAS) is selected

        hipBLASLt is selected when NVTE_USE_HIPBLASLT is set, or when
        NVTE_USE_ROCBLAS is not set. The result is cached, so the
        environment is only consulted on the first call.
        """
        if os.getenv("NVTE_USE_HIPBLASLT") is not None:
            return True
        return os.getenv("NVTE_USE_ROCBLAS") is None

52
53
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Supported data types (bf16 requires sm_80 or higher)
_dtypes: list[torch.dtype] = [
    torch.float32,
    torch.float16,
    *([torch.bfloat16] if is_bf16_compatible() else []),
]

# Supported devices
_devices: list[torch.device] = [torch.device(kind) for kind in ("cpu", "cuda")]

# Supported quantization recipes (None means no quantization)
_quantization_list: list[Optional[str]] = [
    None,
    *(("fp8_delayed_scaling", "fp8_current_scaling") if fp8_available else ()),
    *(("mxfp8",) if mxfp8_available else ()),
]

71

72
73
74
75
76
77
def maybe_skip_quantization(
    quantization: Optional[str],
    *,
    dims: Optional[Iterable[int] | int] = None,
    device: Optional[torch.device | str] = None,
) -> None:
    """Skip test case if a quantization scheme is not supported

    Parameters
    ----------
    quantization: str, optional
        Name of the quantization scheme. Nothing is checked (and the
        test is never skipped) if this is `None`.
    dims: iterable of int or int, optional
        Tensor dimensions to check against the alignment requirements
        of quantized GEMMs. All dimensions except the last are
        flattened together. Both the flattened size and the last
        dimension must be divisible by 16 (FP8) or 32 (MXFP8). Any
        iterable is accepted, including one-shot iterables such as
        generators.
    device: torch.device or str, optional
        Device the test runs on. Quantized tests are skipped on
        non-CUDA devices.

    """

    # Don't skip if there is no quantization
    if quantization is None:
        return

    # Check if quantization scheme is supported
    is_fp8 = quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling")
    is_mxfp8 = quantization == "mxfp8"
    if is_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if is_mxfp8 and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Check if dims are compatible with quantized GEMMs
    if dims is not None:
        # Materialize into a tuple: the checks below slice and index the
        # dims, which fails for iterables that are not sequences
        dims = tuple(dims) if isinstance(dims, Iterable) else (dims,)
        if is_fp8:
            if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
                pytest.skip("FP8 GEMMs require dims that are divisible by 16")
        elif is_mxfp8:
            if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0:
                pytest.skip("MXFP8 GEMMs require dims that are divisible by 32")

    # Check if device is supported
    if device is not None and torch.device(device).type != "cuda":
        pytest.skip("Quantization is only supported on CUDA devices")


105
106
107
@torch.no_grad()
def make_reference_and_test_tensors(
    shape: int | Iterable[int],
    quantization: Optional[str] = None,
    ref_dtype: torch.dtype = torch.float64,
    ref_device: torch.device = "cpu",
    test_dtype: torch.dtype = torch.float32,
    test_device: torch.device = "cuda",
    test_is_quantized: bool = False,
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Construct tensors with the same values

    The reference tensor is intended for use in plain PyTorch
    operations in high precision. The test tensor is intended for use
    in Transformer Engine operations.

    If a quantization scheme is provided, the tensor values are
    quantized so that they are representable.

    Parameters
    ----------
    shape: int or iterable of int
        Tensor shape, passed to ``torch.rand`` (values are uniform in
        [0, 1) before any quantization).
    quantization: str, optional
        Quantization scheme: ``"fp8"`` / ``"fp8_delayed_scaling"``,
        ``"fp8_current_scaling"``, or ``"mxfp8"``.
    ref_dtype, ref_device:
        Dtype and device of the reference tensor.
    test_dtype, test_device:
        Dtype and device of the test tensor.
    test_is_quantized: bool
        Whether the test tensor is returned as a quantized tensor. If
        not, a quantized test tensor is dequantized before returning.
        Requires a quantization scheme.
    requires_grad: bool
        ``requires_grad`` flag applied to both returned tensors.

    Returns
    -------
    tuple of torch.Tensor
        Reference tensor and test tensor.

    Raises
    ------
    ValueError
        If ``test_is_quantized`` is set without a quantization scheme,
        or if the quantization scheme is not recognized.

    """

    # Random reference tensor, copied into the test dtype/device
    reference = torch.rand(shape, dtype=ref_dtype, device=ref_device)
    candidate = reference.to(device=test_device, dtype=test_dtype)

    if quantization is None:
        if test_is_quantized:
            raise ValueError("Quantization scheme not provided")
        # `Tensor.to` returns the same tensor when dtype and device
        # already match, so make sure the two do not share storage
        if candidate.data_ptr() == reference.data_ptr():
            candidate = candidate.clone()
    else:
        # Pick a quantizer for the requested scheme, then apply it once
        if quantization in ("fp8", "fp8_delayed_scaling"):
            quantizer = Float8Quantizer(
                scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
                amax=torch.zeros(1, dtype=torch.float32, device=test_device),
                fp8_dtype=tex.DType.kFloat8E4M3,
            )
        elif quantization == "fp8_current_scaling":
            quantizer = Float8CurrentScalingQuantizer(
                fp8_dtype=tex.DType.kFloat8E4M3,
                device=test_device,
            )
        elif quantization == "mxfp8":
            quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
        else:
            raise ValueError(f"Unsupported quantization scheme ({quantization})")
        candidate = quantizer(candidate)

    # Fall back to a plain tensor unless a quantized one was requested
    if not test_is_quantized and isinstance(candidate, QuantizedTensor):
        candidate = candidate.dequantize()

    # Round-trip the (possibly quantized) values into the reference so
    # that both tensors hold exactly the same values
    reference.copy_(candidate)

    for tensor in (reference, candidate):
        tensor.requires_grad_(requires_grad)
    return reference, candidate


class TestSequential:
    """Tests for sequential container"""

    @staticmethod
    def _check_same_modules(
        container: Iterable[torch.nn.Module],
        expected: Iterable[torch.nn.Module],
    ) -> None:
        """Assert that two containers hold the same module objects in order

        A bare ``zip`` stops at the shorter input. A container that
        failed to add a module, or kept one that should have been
        removed, would then go unnoticed if the mismatch is at the end.
        Comparing lengths first closes that gap.

        """
        actual_list = list(container)
        expected_list = list(expected)
        assert len(actual_list) == len(expected_list)
        for module1, module2 in zip(actual_list, expected_list):
            assert module1 is module2

    def test_modules(self) -> None:
        """Check that list of modules can be manipulated as expected"""

        # Construct sequential container
        modules = [
            te_ops.Identity(),
            te_ops.Identity(),
            torch.nn.Identity(),
            te_ops.Identity(),
        ]
        model = te_ops.Sequential(*modules)

        # Length
        assert len(model) == len(modules)

        # Iterator
        self._check_same_modules(model, modules)

        # Index by int (non-negative and negative indices)
        for i, module in enumerate(modules):
            assert model[i] is module
            assert model[i - len(modules)] is module

        # Index by slice
        model_subset = model[1:-1]
        modules_subset = modules[1:-1]
        assert isinstance(model_subset, te_ops.Sequential)
        self._check_same_modules(model_subset, modules_subset)

        # Set element
        new_module = torch.nn.Identity()
        idx = 1
        modules[idx] = new_module
        model[idx] = new_module
        self._check_same_modules(model, modules)

        # Delete element
        idx = 1
        del modules[idx]
        del model[idx]
        self._check_same_modules(model, modules)

        # Append
        new_module = torch.nn.Identity()
        modules.append(new_module)
        model.append(new_module)
        self._check_same_modules(model, modules)

        # Extend
        new_modules = [te_ops.Identity(), te_ops.Identity()]
        modules.extend(new_modules)
        model.extend(new_modules)
        self._check_same_modules(model, modules)

        # Insert
        new_module = te_ops.Identity()
        idx = 2
        modules.insert(idx, new_module)
        model.insert(idx, new_module)
        self._check_same_modules(model, modules)

        # Pop
        idx = 2
        assert model.pop(idx) is modules.pop(idx)
        self._check_same_modules(model, modules)

        # Out-of-place add (the original container must be unchanged)
        new_modules = [torch.nn.Identity(), te_ops.Identity()]
        added_modules = modules + new_modules
        added_model = model + te_ops.Sequential(*new_modules)
        self._check_same_modules(model, modules)
        self._check_same_modules(added_model, added_modules)

        # In-place add
        new_modules = [te_ops.Identity(), torch.nn.Identity()]
        modules += new_modules
        model += te_ops.Sequential(*new_modules)
        self._check_same_modules(model, modules)

    def test_module_groups(self) -> None:
        """Check that modules are grouped together correctly"""
        model = te_ops.Sequential(
            te_ops.Identity(),
            te_ops.Identity(),
            torch.nn.Identity(),
            torch.nn.Identity(),
            te_ops.Identity(),
            torch.nn.Identity(),
            te_ops.Identity(),
            te_ops.Identity(),
            te_ops.Identity(),
        )
        model(torch.zeros(1))
        # Presumably consecutive te_ops operations share a group while each
        # torch.nn module is its own group:
        # [te, te], [nn], [nn], [te], [nn], [te, te, te]
        assert len(model._module_groups) == 6


class TestFuser:
    """Tests for operation fusion infrastructure"""

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA RNGs before the tests in this class"""
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_scale_update(
        self,
        size: int = 16,
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ):
        """Test FP8 scaling factors with delayed scaling recipe

        Weights, inputs, and grad outputs are filled with constants, so
        the expected outputs, amaxes, and scaling factors at every
        training step can be computed by hand.

        """

        # FP8 recipe
        margin = 2
        fp8_format = transformer_engine.common.recipe.Format.HYBRID
        recipe = transformer_engine.common.recipe.DelayedScaling(
            margin=margin,
            fp8_format=fp8_format,
            amax_history_len=8,
            amax_compute_algo="max",
        )

        # Construct model
        with te.fp8_model_init(recipe=recipe):
            model = te_ops.basic.BasicLinear(
                size,
                size,
                device=device,
                dtype=dtype,
            )

        # Training steps
        w_vals = [2, 5, 3, 11]
        x_vals = [7, 3, 5]
        dy_vals = [1, 2, 1]
        with torch.no_grad():
            model.weight.fill_(w_vals[0])
        for step in range(3):

            # Data tensors
            x = torch.full(
                (size, size),
                x_vals[step],
                dtype=dtype,
                device=device,
                requires_grad=True,
            )
            dy = torch.full(
                (size, size),
                dy_vals[step],
                dtype=dtype,
                device=device,
            )

            # Training step, then set the weight for the next step
            with te.fp8_autocast(fp8_recipe=recipe):
                y = model(x)
            y.backward(dy)
            with torch.no_grad():
                model.weight.fill_(w_vals[step + 1])

            # Check that output tensors match expected
            # (each entry is a sum of `size` identical products)
            y_val_ref = w_vals[step] * x_vals[step] * size
            dx_val_ref = w_vals[step] * dy_vals[step] * size
            torch.testing.assert_close(
                y,
                torch.full_like(y, y_val_ref),
                **dtype_tols(tex.DType.kFloat8E4M3),
            )
            torch.testing.assert_close(
                x.grad,
                torch.full_like(x.grad, dx_val_ref),
                **dtype_tols(tex.DType.kFloat8E5M2),
            )

            # Check that scaling factors match expected
            # (amax_compute_algo="max" with a history of length 8, longer
            # than the 3 steps, so the amax is the running max so far)
            w_amax_ref = max(w_vals[: step + 1])
            x_amax_ref = max(x_vals[: step + 1])
            dy_amax_ref = max(dy_vals[: step + 1])
            w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin)
            x_scale_ref = (fp8_format.value.max_fwd / x_amax_ref) / (2**margin)
            dy_scale_ref = (fp8_format.value.max_bwd / dy_amax_ref) / (2**margin)
            w_scale = model.get_quantizer("forward", 1).scale
            x_scale = model.get_quantizer("forward", 0).scale
            dy_scale = model.get_quantizer("backward", 0).scale
            torch.testing.assert_close(w_scale, torch.full_like(w_scale, w_scale_ref))
            torch.testing.assert_close(x_scale, torch.full_like(x_scale, x_scale_ref))
            torch.testing.assert_close(dy_scale, torch.full_like(dy_scale, dy_scale_ref))

    @pytest.mark.parametrize("init_dtype", _dtypes)
    @pytest.mark.parametrize("final_dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_dtype_cast(
        self,
        *,
        size: int = 32,
        init_dtype: torch.dtype,
        final_dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
    ) -> None:
        """Check dtype cast functions"""

        # Skip invalid configurations
        in_shape = (size, size)
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        # Generate values in a reduced-precision dtype if either dtype is
        # reduced precision (bf16 takes priority over fp16). The weight
        # comparison below uses tolerances for this dtype.
        dtype = torch.float32
        if torch.float16 in (init_dtype, final_dtype):
            dtype = torch.float16
        if torch.bfloat16 in (init_dtype, final_dtype):
            dtype = torch.bfloat16
        w_ref, w_test = make_reference_and_test_tensors(
            (size, size),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )

        # Construct operation
        with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype)
        with torch.no_grad():
            op.weight.copy_(w_test)
            del w_test

        # Cast operation dtype
        if final_dtype == torch.float32:
            op.float()
        elif final_dtype == torch.float16:
            op.half()
        elif final_dtype == torch.bfloat16:
            op.bfloat16()

        # Check weights
        assert isinstance(op.weight, QuantizedTensor) == with_quantization
        assert op.weight.dtype == final_dtype
        w_test = op.weight.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(w_test, w_ref, **dtype_tols(dtype))

        # Check forward and backward pass
        x = torch.zeros(
            in_shape,
            dtype=init_dtype,
            device=device,
            requires_grad=True,
        )
        y = op(x)
        y.backward(torch.zeros_like(y))
        assert y.dtype == final_dtype
        assert x.grad.dtype == init_dtype
        assert op.weight.grad.dtype == final_dtype

    @pytest.mark.parametrize("model_dtype", _dtypes)
    @pytest.mark.parametrize("autocast_dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_pyt_autocast(
        self,
        *,
        size: int = 32,
        model_dtype: torch.dtype,
        autocast_dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weights: bool = False,
    ) -> None:
        """Test with PyTorch autocast"""
        device = torch.device(device)

        # Skip invalid configurations
        in_shape = (size, size)
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Construct operation
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weights, recipe=recipe):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype)

        # Check forward and backward pass
        x = torch.zeros(
            in_shape,
            dtype=model_dtype,
            device=device,
            requires_grad=True,
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
                y = op(x)
        y.backward(torch.zeros_like(y))
        assert y.dtype == autocast_dtype
        assert x.grad.dtype == model_dtype
        assert op.weight.grad.dtype == model_dtype

        # Check forward and backward pass (swapped context order). Only
        # run with quantization, since fp8_autocast is disabled otherwise.
        if quantized_compute:
            x.grad = None
            op.weight.grad = None
            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
                with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
                    y = op(x)
            y.backward(torch.zeros_like(y))
            assert y.dtype == autocast_dtype
            assert x.grad.dtype == model_dtype
            assert op.weight.grad.dtype == model_dtype

491
492
493
494
495
496
497
498
499
500
501
502
503

class TestBasicOps:
    """Tests for individual operations"""

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA RNGs before the tests in this class"""
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_identity(
        self,
        *,
        in_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device,
        quantization: Optional[str],
    ) -> None:
        """Identity op returns its input and passes the grad through unchanged"""

        # Skip invalid configurations
        is_quantized = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data (input first, then grad output)
        tensor_kwargs = dict(
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=is_quantized,
        )
        x_ref, x_test = make_reference_and_test_tensors(in_shape, **tensor_kwargs)
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            requires_grad=False,
            **tensor_kwargs,
        )

        # Plain PyTorch implementation: identity in both directions
        y_ref, dx_ref = x_ref, dy_ref

        # Implementation with fusible operation
        y_test = te_ops.Identity()(x_test)
        y_test.backward(dy_test)

        # Check results (identity is exact, so no tolerance is allowed)
        exact = dict(rtol=0, atol=0)
        as_ref = dict(dtype=torch.float64, device="cpu")
        y_test = y_test.to(**as_ref)
        dx_test = x_test.grad.to(**as_ref)
        pairs = ((y_test, y_ref), (dx_test, dx_ref))
        for actual, expected in pairs:
            torch.testing.assert_close(actual, expected, **exact)

        # Make sure we are not trivially passing the test
        for actual, expected in pairs:
            with pytest.raises(AssertionError):
                torch.testing.assert_close(actual, -expected, **exact)

    @pytest.mark.parametrize(
        "shapes",
        (
            ((1, 2, 3, 4), (2, 12)),
            ((5, 4, 3, 2), (-1, 6)),
            ((30,), (2, 3, -1)),
            ((6, 7), (3, -1, 7)),
        ),
    )
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
    def test_reshape(
        self,
        *,
        shapes: tuple[Iterable[int], Iterable[int]],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        memory_format: torch.memory_format = torch.contiguous_format,
        quantization: Optional[str],
    ) -> None:
        """Reshape op matches torch.Tensor.reshape in forward and backward"""
        in_shape, out_shape = shapes

        # Skip invalid configurations
        if memory_format == torch.channels_last and len(in_shape) != 4:
            pytest.skip("torch.channels_last only supports 4D tensors")
        maybe_skip_quantization(quantization, device=device)
        is_quantized = quantization is not None

        # Random data; the grad output takes the shape of the reshaped input
        tensor_kwargs = dict(
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=is_quantized,
        )
        x_ref, x_test = make_reference_and_test_tensors(in_shape, **tensor_kwargs)
        x_test = x_test.contiguous(memory_format=memory_format)
        x_test = x_test.detach().requires_grad_()
        y_size = x_ref.reshape(out_shape).size()
        dy_ref, dy_test = make_reference_and_test_tensors(
            y_size,
            requires_grad=False,
            **tensor_kwargs,
        )

        # Plain PyTorch implementation
        y_ref = x_ref.reshape(out_shape)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        y_test = te_ops.Reshape(out_shape)(x_test)
        y_test.backward(dy_test)

        # Check results (reshape is exact, so no tolerance is allowed)
        exact = dict(rtol=0, atol=0)
        as_ref = dict(
            dtype=torch.float64,
            device="cpu",
            memory_format=torch.contiguous_format,
        )
        y_test = y_test.to(**as_ref)
        dx_test = x_test.grad.to(**as_ref)
        torch.testing.assert_close(y_test, y_ref, **exact)
        torch.testing.assert_close(dx_test, x_ref.grad, **exact)

    @pytest.mark.parametrize("size", (1, 7, 32))
    @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (4, 3, 8, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", _devices)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_bias(
        self,
        *,
        size: int,
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device,
        quantization: Optional[str],
    ) -> None:
        """Bias op matches a broadcasted add in forward and backward"""

        # The last entry of in_shape is a placeholder for the bias size
        in_shape = list(in_shape)[:-1] + [size]

        # Skip invalid configurations
        is_quantized = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data (input, bias, then grad output)
        data_kwargs = dict(
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=is_quantized,
        )
        x_ref, x_test = make_reference_and_test_tensors(in_shape, **data_kwargs)
        b_ref, b_test = make_reference_and_test_tensors(
            size,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            requires_grad=False,
            **data_kwargs,
        )

        # Plain PyTorch implementation: broadcast bias over leading dims
        bias_view_shape = [1] * (len(in_shape) - 1) + [size]
        y_ref = x_ref + b_ref.reshape(bias_view_shape)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.Bias(size, device=device, dtype=dtype)
        with torch.no_grad():
            op.bias.copy_(b_test)
            del b_test
        y_test = op(x_test)
        y_test.backward(dy_test)

        # Check results
        tols = dtype_tols(dtype)
        comparisons = [
            (y_test, y_ref),
            (x_test.grad, x_ref.grad),
            (op.bias.grad, b_ref.grad),
        ]
        comparisons = [
            (actual.to(dtype=torch.float64, device="cpu"), expected)
            for actual, expected in comparisons
        ]
        for actual, expected in comparisons:
            torch.testing.assert_close(actual, expected, **tols)

    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("cast_forward", (False, True))
    @pytest.mark.parametrize("cast_backward", (False, True))
    def test_quantize(
        self,
        *,
        in_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device = "cuda",
        quantization: Optional[str],
        cast_forward: bool,
        cast_backward: bool,
    ) -> None:
        """Quantize

        The input and grad output are generated with values that are
        representable in the quantization scheme, so the Quantize op is
        expected to preserve values exactly. With a quantization scheme,
        the output (resp. grad input) is expected to be a
        ``QuantizedTensor`` exactly when ``cast_forward``
        (resp. ``cast_backward``) is set. ``quantization`` may be
        ``None``, since it is parametrized over ``_quantization_list``.

        """

        # Skip invalid configurations
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, device=device)
        if quantization == "mxfp8":
            # MXFP8 additionally constrains the tensor dims
            maybe_skip_quantization(quantization, dims=in_shape)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=True,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = x_ref
        dx_ref = dy_ref

        # Implementation with fusible operation
        op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
        recipe = make_recipe(quantization)
        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
            y_test = op(x_test)
        y_test.backward(dy_test)

        # Check tensor types
        if with_quantization:
            assert isinstance(y_test, QuantizedTensor) == cast_forward
            assert isinstance(x_test.grad, QuantizedTensor) == cast_backward

        # Check values
        tols = dict(rtol=0, atol=0)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, dx_ref, **tols)

    def _test_basic_linear(
        self,
        *,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
760
761
762
763
764
765
766
        quantization: Optional[str] = None,
        quantized_compute: bool = False,
        quantized_input: bool = False,
        quantized_weight: bool = False,
        quantized_output: bool = False,
        quantized_grad_output: bool = False,
        quantized_grad_input: bool = False,
Tim Moon's avatar
Tim Moon committed
767
768
769
        accumulate_into_main_grad: bool = False,
    ) -> None:
        """Helper function for tests with GEMM"""
770
771
772
773
774
775
776

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
777
778
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
        quantization_needed = any(
            (
                quantized_compute,
                quantized_input,
                quantized_weight,
                quantized_output,
                quantized_grad_output,
                quantized_grad_input,
            )
        )
        if quantization is None and quantization_needed:
            pytest.skip("Quantization scheme is not specified")
        if quantization is not None and not quantization_needed:
            pytest.skip("Quantization scheme is not used")
        if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"):
            if quantized_output and not quantized_compute:
                pytest.skip("FP8 output is only supported with FP8 GEMMs")
            if quantized_grad_input and not quantized_compute:
                pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
798
799
800
801
        if quantization == "mxfp8" and quantized_output:
            pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
        if quantization == "mxfp8" and quantized_grad_input:
            pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")
yuguo's avatar
yuguo committed
802
803
804
        if ( IS_HIP_EXTENSION and not use_hipblaslt() and
            accumulate_into_main_grad and dtype != torch.float32 and not quantized_compute):
            pytest.skip("Parameters combination is not supported by ROCBLAS")
805
806
807
808

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
809
            quantization=quantization,
810
811
            test_dtype=dtype,
            test_device=device,
812
            test_is_quantized=quantized_input,
813
814
815
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
816
            quantization=quantization,
817
818
819
820
821
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
822
            quantization=quantization,
823
824
            test_dtype=dtype,
            test_device=device,
825
            test_is_quantized=quantized_grad_output,
826
827
828
829
830
831
832
833
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
834
835
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
836
837
838
839
840
841
842
843
844
845
846
            op = te_ops.BasicLinear(
                in_features,
                out_features,
                device=device,
                dtype=dtype,
                accumulate_into_main_grad=accumulate_into_main_grad,
            )
        with torch.no_grad():
            op.weight.copy_(w_test)
            del w_test
            op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
Tim Moon's avatar
Tim Moon committed
847
        forward = te_ops.Sequential(
848
            te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
Tim Moon's avatar
Tim Moon committed
849
            op,
850
            te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
Tim Moon's avatar
Tim Moon committed
851
        )
852
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
Tim Moon's avatar
Tim Moon committed
853
            y_test = forward(x_test)
854
855
856
857
858
859
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
860
861
        if quantized_compute or quantized_output or quantized_grad_input:
            tols = dtype_tols(tex.DType.kFloat8E4M3)
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        if accumulate_into_main_grad:
            if op.weight.grad is not None:
                torch.testing.assert_close(
                    op.weight.grad,
                    torch.zeros_like(op.weight.grad),
                    rtol=0,
                    atol=0,
                )
            dw_test = op.weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5
        else:
            dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(
                op.weight.main_grad,
                torch.full_like(op.weight.main_grad, 0.5),
                rtol=0,
                atol=0,
            )
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)

887
888
    @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5)))
    @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
    def test_basic_linear(
        self,
        *,
        weight_shape: tuple[int, int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        quantization: Optional[str],
        accumulate_into_main_grad: bool,
    ) -> None:
        """GEMM

        Delegates to ``_test_basic_linear``. The GEMM is quantized exactly
        when a quantization scheme is given; only ``quantized_compute`` is
        set here, so the other ``quantized_*`` flags keep the helper's
        defaults.

        """
        use_quantized_gemm = quantization is not None
        self._test_basic_linear(
            accumulate_into_main_grad=accumulate_into_main_grad,
            dtype=dtype,
            in_shape=in_shape,
            quantization=quantization,
            quantized_compute=use_quantized_gemm,
            weight_shape=weight_shape,
        )

    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_compute", (False, True))
    @pytest.mark.parametrize("quantized_input", (False, True))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    @pytest.mark.parametrize("quantized_output", (False, True))
    @pytest.mark.parametrize("quantized_grad_output", (False, True))
    @pytest.mark.parametrize("quantized_grad_input", (False, True))
    def test_basic_linear_quantized(
        self,
        *,
        quantization: str,
        quantized_compute: bool,
        quantized_input: bool,
        quantized_weight: bool,
        quantized_output: bool,
        quantized_grad_output: bool,
        quantized_grad_input: bool,
    ) -> None:
        """GEMM with quantized inputs and outputs

        Sweeps every combination of quantized compute, input, weight,
        output, grad-output and grad-input for a BF16 basic linear op.
        The case without a quantization scheme is skipped.

        """
        if quantization is None:
            pytest.skip("Skipping case without quantization")

        # Forward every quantization toggle unchanged to the shared helper
        quantized_flags = dict(
            quantized_compute=quantized_compute,
            quantized_input=quantized_input,
            quantized_weight=quantized_weight,
            quantized_output=quantized_output,
            quantized_grad_output=quantized_grad_output,
            quantized_grad_input=quantized_grad_input,
        )
        self._test_basic_linear(
            dtype=torch.bfloat16,
            quantization=quantization,
            **quantized_flags,
        )

944
    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_compute", (False, True))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    @pytest.mark.parametrize("input_requires_grad", (False, True))
    @pytest.mark.parametrize("weight_requires_grad", (False, True))
    def test_linear(
        self,
        *,
        bias: bool,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_compute: bool,
        quantized_weight: bool,
        input_requires_grad: bool,
        weight_requires_grad: bool,
    ) -> None:
        """GEMM + bias

        Compares ``te_ops.Linear`` against ``torch.nn.functional.linear``.

        ``input_requires_grad`` controls whether autograd tracks the input
        tensor. ``weight_requires_grad`` controls whether it tracks the op's
        parameters. The backward pass only runs when at least one of them
        requires grad. Only the gradients that were requested are checked.

        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantization is None and (quantized_compute or quantized_weight):
            pytest.skip("Quantization scheme is not specified")
        if quantization is not None and not (quantized_compute or quantized_weight):
            pytest.skip("Quantization scheme is not used")

        # Random data
        # Autograd tracks the input only when requested. With
        # input_requires_grad=False, the op must then run without computing
        # a data gradient.
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=input_requires_grad,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref, b_test = None, None
        if bias:
            b_ref, b_test = make_reference_and_test_tensors(
                out_features,
                test_dtype=dtype,
                test_device=device,
            )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        # w_ref is always created with default requires_grad (i.e. not
        # passing requires_grad=False), so this backward call is valid even
        # when x_ref does not require grad.
        y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            op = te_ops.Linear(
                in_features,
                out_features,
                bias=bias,
                device=device,
                dtype=dtype,
            )
        with torch.no_grad():
            op.weight.copy_(w_test)
            if bias:
                op.bias.copy_(b_test)
            del w_test
            del b_test
            for param in op.parameters():
                param.requires_grad_(requires_grad=weight_requires_grad)
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = op(x_test)
        # Without any tensor requiring grad there is no graph to backprop
        if input_requires_grad or weight_requires_grad:
            y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        if input_requires_grad:
            dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        if weight_requires_grad:
            dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(dw_test, w_ref.grad, **tols)
            if bias:
                db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
                torch.testing.assert_close(db_test, b_ref.grad, **tols)
1053

1054
1055
    @pytest.mark.parametrize("weight_shape", ((7, 2), (32,)))
    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("zero_centered_gamma", (False, True))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_layer_norm(
        self,
        *,
        weight_shape: Iterable[int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        eps: float = 0.3,
        zero_centered_gamma: bool,
        quantization: Optional[str],
    ) -> None:
        """Layer norm

        Checks ``te_ops.LayerNorm``, optionally followed by an output
        quantization, against ``torch.nn.functional.layer_norm``. The
        output and the gradients for the input, weight and bias are all
        compared.

        """

        # The last input dim is replaced by the full normalized shape
        leading_dims = list(in_shape)[:-1]
        in_shape = leading_dims + list(weight_shape)

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data (creation order is kept stable for reproducibility)
        def random_pair(shape, **kwargs):
            return make_reference_and_test_tensors(
                shape,
                test_dtype=dtype,
                test_device=device,
                **kwargs,
            )

        x_ref, x_test = random_pair(in_shape)
        w_ref, w_test = random_pair(weight_shape)
        b_ref, b_test = random_pair(weight_shape)
        dy_ref, dy_test = random_pair(in_shape, requires_grad=False)

        # Plain PyTorch reference; a zero-centered gamma stores (gamma - 1)
        gamma_ref = w_ref + 1 if zero_centered_gamma else w_ref
        y_ref = torch.nn.functional.layer_norm(
            x_ref,
            weight_shape,
            weight=gamma_ref,
            bias=b_ref,
            eps=eps,
        )
        y_ref.backward(dy_ref)

        # Fusible-op implementation
        op = te_ops.LayerNorm(
            weight_shape,
            eps=eps,
            device=device,
            dtype=dtype,
            zero_centered_gamma=zero_centered_gamma,
        )
        with torch.no_grad():
            op.weight.copy_(w_test)
            op.bias.copy_(b_test)
        del w_test, b_test

        quantized_compute = quantization is not None
        recipe = make_recipe(quantization)
        model = te_ops.Sequential(
            op,
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)

        # A quantized output needs FP8-level tolerances
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)
        else:
            tols = dtype_tols(dtype)

        # Compare everything in float64 on the CPU
        comparisons = [
            (actual.to(dtype=torch.float64, device="cpu"), expected)
            for actual, expected in (
                (y_test, y_ref),
                (x_test.grad, x_ref.grad),
                (op.weight.grad, w_ref.grad),
                (op.bias.grad, b_ref.grad),
            )
        ]
        for actual, expected in comparisons:
            torch.testing.assert_close(actual, expected, **tols)

    def test_layer_norm_autocast(
        self,
        *,
        weight_shape: Iterable[int] = (32,),
        in_shape: Iterable[int] = (32,),
        dtype: torch.dtype = torch.float16,
        autocast_dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        eps: float = 0.3,
    ) -> None:
        """Layer norm with PyTorch autocast

        The parameters are kept in ``dtype``. The input, output and their
        gradients use ``autocast_dtype``. The op's output must come out of
        ``torch.autocast`` in ``autocast_dtype``.

        """
        in_shape = list(in_shape)[:-1] + list(weight_shape)

        # Activations use the autocast dtype; parameters use ``dtype``
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape, test_dtype=autocast_dtype, test_device=device
        )
        w_ref, w_test = make_reference_and_test_tensors(
            weight_shape, test_dtype=dtype, test_device=device
        )
        b_ref, b_test = make_reference_and_test_tensors(
            weight_shape, test_dtype=dtype, test_device=device
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=autocast_dtype,
            test_device=device,
            requires_grad=False,
        )

        # Reference
        y_ref = torch.nn.functional.layer_norm(
            x_ref, weight_shape, weight=w_ref, bias=b_ref, eps=eps
        )
        y_ref.backward(dy_ref)

        # Fusible op under autocast
        op = te_ops.LayerNorm(weight_shape, eps=eps, device=device, dtype=dtype)
        with torch.no_grad():
            op.weight.copy_(w_test)
            op.bias.copy_(b_test)
        del w_test, b_test
        with torch.autocast(device, dtype=autocast_dtype):
            y_test = op(x_test)
        y_test.backward(dy_test)

        assert y_test.dtype == autocast_dtype

        def to_cpu64(t):
            return t.to(dtype=torch.float64, device="cpu")

        y_test, dx_test, dw_test, db_test = (
            to_cpu64(t) for t in (y_test, x_test.grad, op.weight.grad, op.bias.grad)
        )
        activation_tols = dtype_tols(autocast_dtype)
        param_tols = dtype_tols(dtype)
        torch.testing.assert_close(y_test, y_ref, **activation_tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **activation_tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **param_tols)
        torch.testing.assert_close(db_test, b_ref.grad, **param_tols)

1224
1225
    @pytest.mark.parametrize("weight_shape", ((19,), (64,)))
    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("zero_centered_gamma", (False, True))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_rmsnorm(
        self,
        *,
        weight_shape: Iterable[int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        eps: float = 0.3,
        zero_centered_gamma: bool,
        quantization: Optional[str],
    ) -> None:
        """RMSNorm

        Checks ``te_ops.RMSNorm``, optionally followed by an output
        quantization, against the plain PyTorch reference
        ``y = x / sqrt(eps + mean(x**2)) * gamma``. The mean is taken over
        the normalized (trailing) dims. ``gamma`` is ``1 + weight`` when
        ``zero_centered_gamma`` is set, otherwise ``weight``. The output
        and the input/weight gradients are compared.

        """

        # Make input and weight shapes consistent
        in_shape = list(in_shape)[:-1] + list(weight_shape)

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            weight_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        # Mean of squares over the trailing dims covered by weight_shape
        inner_dims = tuple(range(len(in_shape) - len(weight_shape), len(in_shape)))
        var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape)
        if zero_centered_gamma:
            y_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref)
        else:
            y_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.RMSNorm(
            weight_shape,
            eps=eps,
            device=device,
            dtype=dtype,
            zero_centered_gamma=zero_centered_gamma,
        )
        with torch.no_grad():
            op.weight.copy_(w_test)
            del w_test
        quantized_compute = quantization is not None
        recipe = make_recipe(quantization)
        forward = te_ops.Sequential(
            op,
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)

1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
    @pytest.mark.parametrize("in_shape", ((32,), (6, 16, 64), (32, 64)))
    @pytest.mark.parametrize("dtype", _dtypes)
    def test_l2normalization(
        self,
        *,
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        eps: float = 1e-6,
    ) -> None:
        """L2 Normalization

        Checks ``te_ops.L2Normalization`` against
        ``y = x * rsqrt(sum(x**2, dim=-1) + eps)``, for both the output and
        the input gradient.

        """
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape, test_dtype=dtype, test_device=device
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape, test_dtype=dtype, test_device=device, requires_grad=False
        )

        # Reference: scale each last-dim vector by 1 / sqrt(||x||^2 + eps)
        inv_norm = torch.rsqrt(x_ref.pow(2).sum(dim=-1, keepdim=True) + eps)
        y_ref = x_ref * inv_norm
        y_ref.backward(dy_ref)

        op = te_ops.L2Normalization(eps=eps)
        y_test = op(x_test)
        y_test.backward(dy_test)

        tols = dtype_tols(dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)

        # The BF16 backward pass needs a slightly looser absolute tolerance
        if dtype == torch.bfloat16:
            tols["atol"] = 2e-3
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
Tim Moon's avatar
Tim Moon committed
1360

1361
1362
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_add_in_place(
        self,
        *,
        in_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device,
        quantization: Optional[str],
    ) -> None:
        """Add two tensors, joining two branches of the compute graph

        ``te_ops.AddInPlace`` takes two inputs and returns their sum. Each
        input must receive the output gradient unchanged.

        """
        # Skip invalid configurations
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data; all three tensors share the same settings
        def random_pair(**kwargs):
            return make_reference_and_test_tensors(
                in_shape,
                quantization=quantization,
                test_dtype=dtype,
                test_device=device,
                test_is_quantized=with_quantization,
                **kwargs,
            )

        x1_ref, x1_test = random_pair()
        x2_ref, x2_test = random_pair()
        dy_ref, dy_test = random_pair(requires_grad=False)

        # Reference: accumulate x1 into a detached alias of x2. The gradient
        # of a sum passes straight through to both addends.
        y_ref = x2_ref.detach()
        y_ref += x1_ref
        dx1_ref = dx2_ref = dy_ref

        # Fusible op
        op = te_ops.AddInPlace()
        y_test = op(x1_test, x2_test)
        y_test.backward(dy_test)

        # Quantized inputs are compared with their FP8 format's tolerances
        if with_quantization:
            tols = dtype_tols(x1_test._fp8_dtype)
        else:
            tols = dtype_tols(dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
        dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)

        # Gradients are pass-through, so they must match bit-for-bit
        exact = dict(rtol=0, atol=0)
        torch.testing.assert_close(dx1_test, dx1_ref, **exact)
        torch.testing.assert_close(dx2_test, dx2_ref, **exact)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_make_extra_output(
        self,
        *,
        in_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device,
        quantization: Optional[str],
    ) -> None:
        """Output a tensor twice, splitting the compute graph

        Both outputs of ``te_ops.MakeExtraOutput`` must equal the input
        exactly. The input gradient must be the sum of the two branches'
        gradients.

        """
        # Skip invalid configurations
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data; all three tensors share the same settings
        def random_pair(**kwargs):
            return make_reference_and_test_tensors(
                in_shape,
                quantization=quantization,
                test_dtype=dtype,
                test_device=device,
                test_is_quantized=with_quantization,
                **kwargs,
            )

        x_ref, x_test = random_pair()
        dy1_ref, dy1_test = random_pair(requires_grad=False)
        dy2_ref, dy2_test = random_pair(requires_grad=False)

        # Reference: both branches are the input itself
        y1_ref = y2_ref = x_ref
        (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()

        # Fusible op
        op = te_ops.MakeExtraOutput()
        y1_test, y2_test = op(x_test)
        (y1_test * dy1_test + y2_test * dy2_test).sum().backward()

        # Compare in float64 on the CPU; the forward copies must be exact
        y1_test, y2_test, dx_test = (
            t.to(dtype=torch.float64, device="cpu")
            for t in (y1_test, y2_test, x_test.grad)
        )
        torch.testing.assert_close(y1_test, y1_ref, rtol=0, atol=0)
        torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
        torch.testing.assert_close(dx_test, x_ref.grad, **dtype_tols(dtype))

1493
    @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
    @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("cache_quantized_input", (False, True))
    def test_activation(
        self,
        *,
        activation: str,
        out_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        cache_quantized_input: bool,
    ) -> None:
        """Activation functions

        Compares each TE activation op against its PyTorch equivalent. The
        op is wrapped in ``Quantize`` ops that quantize its output
        (forward) and its grad output (backward) when a quantization scheme
        is given.

        """
        # Gated activations take twice as many input features as they output
        in_shape = list(out_shape)
        if activation in ("geglu", "reglu", "swiglu"):
            in_shape[-1] *= 2

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        if cache_quantized_input:
            maybe_skip_quantization("fp8_current_scaling", device=device)

        # Random data. When the op caches a quantized input, the input values
        # are generated with FP8 current scaling (presumably so they are
        # FP8-representable; confirm against make_reference_and_test_tensors).
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization="fp8_current_scaling" if cache_quantized_input else None,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # PyTorch reference, dispatched by activation kind
        def tanh_gelu(t: torch.Tensor) -> torch.Tensor:
            return torch.nn.functional.gelu(t, approximate="tanh")

        elementwise = {
            "gelu": tanh_gelu,
            "relu": torch.nn.functional.relu,
        }
        gated = {
            "geglu": tanh_gelu,
            "reglu": torch.nn.functional.relu,
            "swiglu": torch.nn.functional.silu,
        }
        if activation in elementwise:
            y_ref = elementwise[activation](x_ref)
        elif activation in gated:
            # First half of the last dim is the gate input; second half is
            # the linear input it multiplies
            gate_in, linear_in = x_ref.chunk(2, dim=-1)
            y_ref = gated[activation](gate_in) * linear_in
        else:
            raise ValueError(f"Unexpected activation function ({activation})")
        y_ref.backward(dy_ref)

        # Fusible-op implementation
        recipe = make_recipe(quantization)
        op_classes = {
            "gelu": te_ops.GELU,
            "relu": te_ops.ReLU,
            "geglu": te_ops.GEGLU,
            "reglu": te_ops.ReGLU,
            "swiglu": te_ops.SwiGLU,
        }
        forward = te_ops.Sequential(
            te_ops.Quantize(forward=False, backward=quantized_compute),
            op_classes[activation](cache_quantized_input=cache_quantized_input),
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Any FP8 involvement loosens the expected error to FP8 tolerances
        if quantized_compute or cache_quantized_input:
            tols = dtype_tols(tex.DType.kFloat8E4M3)
        else:
            tols = dtype_tols(dtype)

        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantize_forward", (False, True))
    @pytest.mark.parametrize("quantize_backward", (False, True))
    def test_swiglu(
        self,
        *,
        out_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantize_forward: bool,
        quantize_backward: bool,
    ) -> None:
        """SwiGLU with optional quantization of its output and grad input

        Compares ``te_ops.SwiGLU`` against ``silu(x1) * x2``, where ``x1``
        and ``x2`` are the two halves of the input's last dimension.

        ``quantize_forward`` quantizes the op's output through a following
        ``Quantize`` op. ``quantize_backward`` quantizes the gradient sent
        back to the input through a preceding ``Quantize`` op. Either one
        requires a quantization scheme.

        """

        # Tensor dimensions: the input holds both halves of the gated product
        in_shape = list(out_shape)
        in_shape[-1] *= 2

        # Skip invalid configurations
        quantized_compute = quantization is not None
        if not quantized_compute and (quantize_forward or quantize_backward):
            pytest.skip("Quantization scheme has not been provided")
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        x1, x2 = x_ref.chunk(2, dim=-1)
        y_ref = torch.nn.functional.silu(x1) * x2
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        recipe = make_recipe(quantization)
        forward = te_ops.Sequential(
            te_ops.Quantize(forward=False, backward=quantize_backward),
            te_ops.SwiGLU(),
            te_ops.Quantize(forward=quantize_forward, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658

class TestFusedOps:
    """Tests for fused operations

    Each test builds a ``te_ops.Sequential`` model, checks that the operation
    fuser selected the expected fused op, and compares outputs and gradients
    against a plain PyTorch reference implementation.
    """

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA RNGs before the tests in this class run"""
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
    @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_weight", (False, True))
    def test_forward_linear_bias_activation(
        self,
        *,
        bias: bool = True,
        weight_shape: tuple[int, int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool,
    ) -> None:
        """Forward GEMM + bias + activation

        Checks that a ``te_ops.Linear`` is fused into
        ``ForwardLinearBiasActivation`` and matches
        ``torch.nn.functional.linear`` in forward and backward.
        ``quantized_weight`` selects whether the weight is created as a
        quantized tensor (via ``te.fp8_model_init``).
        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        if quantized_weight and not quantized_compute:
            # Quantized weights need a quantization scheme; without one the
            # unquantized reference weight cannot be matched within tolerance
            pytest.skip("Quantization scheme has not been provided")
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if dtype not in (torch.float16, torch.bfloat16):
            pytest.skip(
                "FP8 fused linear-bias-activation is only supported with FP16 or BF16 output"
            )

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref, b_test = None, None
        if bias:
            b_ref, b_test = make_reference_and_test_tensors(
                out_features,
                test_dtype=dtype,
                test_device=device,
            )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
        # Weight storage is controlled by quantized_weight, compute by quantization
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                ),
            )
        with torch.no_grad():
            model[0].weight.copy_(w_test)
            if bias:
                model[0].bias.copy_(b_test)
            del w_test
            del b_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)

        # Check that forward operations have been fused
        forward_ops = model._module_groups[0]._forward_ops
        assert len(forward_ops) == 1
        assert isinstance(forward_ops[0][0], ForwardLinearBiasActivation)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
        if bias:
            db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(db_test, b_ref.grad, **tols)

    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_forward_linear_bias_add(
        self,
        *,
        bias: bool,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool = False,
    ) -> None:
        """Forward GEMM + bias + add

        Checks that ``te_ops.Linear`` followed by ``te_ops.AddInPlace`` is
        fused into ``ForwardLinearBiasAdd`` and matches
        ``linear(x1, w, b) + x2`` in forward and backward.
        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        if quantized_weight and not quantized_compute:
            pytest.skip("Quantization scheme has not been provided")
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")

        # Random data
        x1_ref, x1_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref, b_test = None, None
        if bias:
            b_ref, b_test = make_reference_and_test_tensors(
                out_features,
                test_dtype=dtype,
                test_device=device,
            )
        x2_ref, x2_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x1_ref, w_ref, bias=b_ref) + x2_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                ),
                te_ops.AddInPlace(),
            )
        with torch.no_grad():
            model[0].weight.copy_(w_test)
            if bias:
                model[0].bias.copy_(b_test)
            del w_test
            del b_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = model(x1_test, x2_test)
        y_test.backward(dy_test)

        # Check that forward operations have been fused
        forward_ops = model._module_groups[0]._forward_ops
        assert len(forward_ops) == 1
        assert isinstance(forward_ops[0][0], ForwardLinearBiasAdd)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
        dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx1_test, x1_ref.grad, **tols)
        torch.testing.assert_close(dx2_test, x2_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
        if bias:
            db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(db_test, b_ref.grad, **tols)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_backward_linear_add(
        self,
        *,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool = False,
    ) -> None:
        """Backward dgrad GEMM + add

        Checks that ``te_ops.MakeExtraOutput`` followed by ``te_ops.Linear``
        is fused in the backward pass into ``BackwardLinearAdd``. The model
        returns ``(linear(x, w), x)`` and both outputs contribute to the loss.
        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        if quantized_weight and not quantized_compute:
            pytest.skip("Quantization scheme has not been provided")
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        dy1_ref, dy1_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        dy2_ref, dy2_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y1_ref = torch.nn.functional.linear(x_ref, w_ref)
        y2_ref = x_ref
        (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
        # Pass the recipe like the sibling tests so quantized weights use the
        # same quantization scheme as the compute
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.MakeExtraOutput(),
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=False,
                    device=device,
                    dtype=dtype,
                ),
            )
        with torch.no_grad():
            model[1].weight.copy_(w_test)
            del w_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y1_test, y2_test = model(x_test)
        (y1_test * dy1_test + y2_test * dy2_test).sum().backward()

        # Check that backward operations have been fused
        backward_ops = model._module_groups[0]._backward_ops
        assert len(backward_ops) == 1
        assert isinstance(backward_ops[0][0], BackwardLinearAdd)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y1_test = y1_test.to(dtype=torch.float64, device="cpu")
        y2_test = y2_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y1_test, y1_ref, **tols)
        torch.testing.assert_close(y2_test, y2_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)


class TestCheckpointing:
    """Tests for checkpointing"""

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA RNGs before the tests in this class run"""
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_weight", (False, True))
    def test_linear(
        self,
        *,
        pre_checkpoint_steps: int = 2,
        post_checkpoint_steps: int = 2,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool,
    ) -> None:
        """Check checkpointing with linear op

        A model is trained for ``pre_checkpoint_steps``, saved to an in-memory
        checkpoint, then trained for ``post_checkpoint_steps`` more. A second
        model restored from the checkpoint is trained on the same inputs and
        output gradients. Parameters, parameter gradients, outputs, and input
        gradients of both runs must match exactly.
        """

        # Derive input/output shapes that agree with the weight shape
        out_features, in_features = weight_shape
        in_shape = [*list(in_shape)[:-1], in_features]
        out_shape = [*in_shape[:-1], out_features]

        # Skip configurations the quantization scheme cannot handle
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)

        recipe = make_recipe(quantization)

        def build_model_and_optimizer():
            # Linear op (optionally with quantized weights) plus its SGD optimizer
            with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
                net = te_ops.Sequential(
                    te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
                )
            return net, torch.optim.SGD(net.parameters(), lr=0.25)

        def train_step(net, opt, inp, grad_out):
            # One zero_grad / forward / backward / step; returns the forward output
            opt.zero_grad()
            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
                out = net(inp)
            out.backward(grad_out)
            opt.step()
            return out

        # Warmup training before the checkpoint
        model_save, optim_save = build_model_and_optimizer()
        for _ in range(pre_checkpoint_steps):
            warm_x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
            warm_dy = torch.randn(out_shape, dtype=dtype, device=device)
            train_step(model_save, optim_save, warm_x, warm_dy)

        # Serialize model and optimizer state to an in-memory buffer
        with io.BytesIO() as buffer:
            torch.save(
                {"model": model_save.state_dict(), "optim": optim_save.state_dict()},
                buffer,
            )
            checkpoint_bytes = buffer.getvalue()

        # Identical synthetic data for both post-checkpoint runs
        xs_save = [
            torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
            for _ in range(post_checkpoint_steps)
        ]
        with torch.no_grad():
            xs_load = [src.clone().requires_grad_() for src in xs_save]
        dys = [
            torch.randn(out_shape, dtype=dtype, device=device) for _ in range(post_checkpoint_steps)
        ]

        # Continue training the original model
        ys_save = [
            train_step(model_save, optim_save, inp, grad_out)
            for inp, grad_out in zip(xs_save, dys)
        ]

        # Restore a fresh model from the checkpoint. weights_only=False
        # unpickles arbitrary objects; acceptable here because the bytes were
        # produced by this test.
        model_load, optim_load = build_model_and_optimizer()
        restored = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False)
        model_load.load_state_dict(restored["model"])
        optim_load.load_state_dict(restored["optim"])

        # Train the restored model on the same data
        ys_load = [
            train_step(model_load, optim_load, inp, grad_out)
            for inp, grad_out in zip(xs_load, dys)
        ]

        # Both runs must agree bit-for-bit
        exact = {"rtol": 0, "atol": 0}
        for loaded, saved in zip(model_load.parameters(), model_save.parameters()):
            torch.testing.assert_close(loaded, saved, **exact)
            torch.testing.assert_close(loaded.grad, saved.grad, **exact)
        for out_load, out_save in zip(ys_load, ys_save):
            torch.testing.assert_close(out_load, out_save, **exact)
        for inp_load, inp_save in zip(xs_load, xs_save):
            torch.testing.assert_close(inp_load.grad, inp_save.grad, **exact)