test_fusible_ops.py 77.1 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
#
# See LICENSE for license information.

from __future__ import annotations

7
from collections.abc import Iterable
8
import io
9
import math
10
11
import pathlib
import sys
12
from typing import Optional
13
14
15
16
17

import pytest
import torch

import transformer_engine
Tim Moon's avatar
Tim Moon committed
18
import transformer_engine.common.recipe
19
20
21
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
22
23
from transformer_engine.pytorch.ops.fused import (
    BackwardLinearAdd,
24
    ForwardLinearBiasActivation,
25
    ForwardLinearBiasAdd,
26
)
27
from transformer_engine.pytorch.tensor import QuantizedTensor
28
29
30
31
32
33
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Tensor,
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
34
35
36
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex

37
38
39
40
41
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent))
from utils import dtype_tols, make_recipe

# Hardware support for quantized formats, each paired with the reason
# string reported when a test has to be skipped
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Floating-point dtypes exercised by the tests (bf16 requires sm_80 or higher)
_dtypes: list[torch.dtype] = [torch.float32, torch.float16]
_dtypes += [torch.bfloat16] if is_bf16_compatible() else []

# Devices exercised by device-parametrized tests
_devices: list[torch.device] = [torch.device(name) for name in ("cpu", "cuda")]

# Quantization schemes exercised by the tests; None means "no quantization"
_quantization_list: list[Optional[str]] = [None]
if fp8_available:
    _quantization_list += ["fp8_delayed_scaling", "fp8_current_scaling"]
if mxfp8_available:
    _quantization_list += ["mxfp8"]

61

62
63
64
65
66
67
def maybe_skip_quantization(
    quantization: Optional[str],
    *,
    dims: Optional[Iterable[int] | int] = None,
    device: Optional[torch.device | str] = None,
) -> None:
68
    """Skip test case if a quantization scheme is not supported"""
69
70
71
72
73
74

    # Don't skip if there is no quantization
    if quantization is None:
        return

    # Check if quantization scheme is supported
75
    if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
76
77
78
79
80
81
82
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    if dims is not None:
        if not isinstance(dims, Iterable):
            dims = (dims,)
83
        if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"):
84
85
86
87
88
89
90
91
92
93
94
            if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
                pytest.skip("FP8 GEMMs require dims that are divisible by 16")
        elif quantization == "mxfp8":
            if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0:
                pytest.skip("MXFP8 GEMMs require dims that are divisible by 32")

    # Check if device is supported
    if device is not None and torch.device(device).type != "cuda":
        pytest.skip("Quantization is only supported on CUDA devices")


95
96
97
@torch.no_grad()
def make_reference_and_test_tensors(
    shape: int | Iterable[int],
98
    quantization: Optional[str] = None,
99
100
101
102
    ref_dtype: torch.dtype = torch.float64,
    ref_device: torch.device = "cpu",
    test_dtype: torch.dtype = torch.float32,
    test_device: torch.device = "cuda",
103
    test_is_quantized: bool = False,
104
105
106
107
108
109
110
111
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Construct tensors with the same values

    The reference tensor is intended for use in plain PyTorch
    operations in high precision. The test tensor is intended for use
    in Transformer Engine operations.

112
113
114
    If a quantization scheme is provided, the tensor values are
    quantized so that they are representable.

115
    """
116
117

    # Random reference tensor
118
    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
119
120

    # Construct test tensor from reference tensor
121
    test = ref.to(device=test_device, dtype=test_dtype)
122
123
124
125
126
127
    if quantization is None:
        if test_is_quantized:
            raise ValueError("Quantization scheme not provided")
        if test.data_ptr() == ref.data_ptr():
            test = test.clone()
    elif quantization in ("fp8", "fp8_delayed_scaling"):
128
129
130
131
132
133
        quantizer = Float8Quantizer(
            scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
            amax=torch.zeros(1, dtype=torch.float32, device=test_device),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
        test = quantizer(test)
134
135
136
137
138
139
140
141
142
143
144
145
146
147
    elif quantization == "fp8_current_scaling":
        quantizer = Float8CurrentScalingQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            device=test_device,
        )
        test = quantizer(test)
    elif quantization == "mxfp8":
        test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
    else:
        raise ValueError(f"Unsupported quantization scheme ({quantization})")
    if isinstance(test, QuantizedTensor) and not test_is_quantized:
        test = test.dequantize()

    # Make sure reference and test tensors match each other
148
    ref.copy_(test)
149

150
151
152
153
154
    ref.requires_grad_(requires_grad)
    test.requires_grad_(requires_grad)
    return ref, test


class TestSequentialContainer:
    """Tests for sequential container"""

    @staticmethod
    def _assert_same_modules(model, modules) -> None:
        """Assert that ``model`` holds exactly ``modules``, in order

        Lengths are compared as well as elements: ``zip`` stops at the
        shorter input, so an element-only check would not notice a
        dropped or extra module.
        """
        assert len(model) == len(modules)
        for module1, module2 in zip(model, modules):
            assert module1 is module2

    def test_modules(self) -> None:
        """Check that list of modules can be manipulated as expected"""

        # Construct sequential container
        modules = [
            te_ops.Identity(),
            te_ops.Identity(),
            torch.nn.Identity(),
            te_ops.Identity(),
        ]
        model = te_ops.Sequential(*modules)

        # Length and iterator
        self._assert_same_modules(model, modules)

        # Index by int
        for i, module in enumerate(modules):
            assert model[i] is module
            assert model[i - len(modules)] is module

        # Index by slice
        model_subset = model[1:-1]
        modules_subset = modules[1:-1]
        assert isinstance(model_subset, te_ops.Sequential)
        self._assert_same_modules(model_subset, modules_subset)

        # Set element
        new_module = torch.nn.Identity()
        idx = 1
        modules[idx] = new_module
        model[idx] = new_module
        self._assert_same_modules(model, modules)

        # Delete element
        idx = 1
        del modules[idx]
        del model[idx]
        self._assert_same_modules(model, modules)

        # Append
        new_module = torch.nn.Identity()
        modules.append(new_module)
        model.append(new_module)
        self._assert_same_modules(model, modules)

        # Extend
        new_modules = [te_ops.Identity(), te_ops.Identity()]
        modules.extend(new_modules)
        model.extend(new_modules)
        self._assert_same_modules(model, modules)

        # Insert
        new_module = te_ops.Identity()
        idx = 2
        modules.insert(idx, new_module)
        model.insert(idx, new_module)
        self._assert_same_modules(model, modules)

        # Pop
        idx = 2
        assert model.pop(idx) is modules.pop(idx)
        self._assert_same_modules(model, modules)

        # Out-of-place add: original container must be unchanged
        new_modules = [torch.nn.Identity(), te_ops.Identity()]
        added_modules = modules + new_modules
        added_model = model + te_ops.Sequential(*new_modules)
        self._assert_same_modules(model, modules)
        self._assert_same_modules(added_model, added_modules)

        # In-place add
        new_modules = [te_ops.Identity(), torch.nn.Identity()]
        modules += new_modules
        model += te_ops.Sequential(*new_modules)
        self._assert_same_modules(model, modules)

    def test_module_groups(self) -> None:
        """Check that modules are grouped together correctly"""
        # Expected groups: [te, te], [torch], [torch], [te], [torch], [te, te, te]
        model = te_ops.Sequential(
            te_ops.Identity(),
            te_ops.Identity(),
            torch.nn.Identity(),
            torch.nn.Identity(),
            te_ops.Identity(),
            torch.nn.Identity(),
            te_ops.Identity(),
            te_ops.Identity(),
            te_ops.Identity(),
        )
        model(torch.zeros(1))
        assert len(model._module_groups) == 6


class TestFuser:
    """Tests for operation fusion infrastructure"""

    @classmethod
    def setup_class(cls) -> None:
        """Seed the RNGs once for the whole test class"""
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_scale_update(
        self,
        size: int = 16,
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ) -> None:
        """Test FP8 scaling factors with delayed scaling recipe

        Runs a few training steps with constant-valued tensors so that
        the expected outputs and the scaling factors derived from the
        amax history can be computed in closed form.
        """

        # FP8 recipe
        margin = 2
        fp8_format = transformer_engine.common.recipe.Format.HYBRID
        recipe = transformer_engine.common.recipe.DelayedScaling(
            margin=margin,
            fp8_format=fp8_format,
            amax_history_len=8,
            amax_compute_algo="max",
        )

        # Construct model
        with te.fp8_model_init(recipe=recipe):
            model = te_ops.basic.BasicLinear(
                size,
                size,
                device=device,
                dtype=dtype,
            )

        # Training steps (w_vals has one extra entry for the post-step update)
        w_vals = [2, 5, 3, 11]
        x_vals = [7, 3, 5]
        dy_vals = [1, 2, 1]
        with torch.no_grad():
            model.weight.fill_(w_vals[0])
        for step in range(len(x_vals)):

            # Data tensors
            x = torch.full(
                (size, size),
                x_vals[step],
                dtype=dtype,
                device=device,
                requires_grad=True,
            )
            dy = torch.full(
                (size, size),
                dy_vals[step],
                dtype=dtype,
                device=device,
            )

            # Training step
            with te.fp8_autocast(fp8_recipe=recipe):
                y = model(x)
            y.backward(dy)
            with torch.no_grad():
                model.weight.fill_(w_vals[step + 1])

            # Check that output tensors match expected
            y_val_ref = w_vals[step] * x_vals[step] * size
            dx_val_ref = w_vals[step] * dy_vals[step] * size
            torch.testing.assert_close(
                y,
                torch.full_like(y, y_val_ref),
                **dtype_tols(tex.DType.kFloat8E4M3),
            )
            torch.testing.assert_close(
                x.grad,
                torch.full_like(x.grad, dx_val_ref),
                **dtype_tols(tex.DType.kFloat8E5M2),
            )

            # Check that scaling factors match expected
            w_amax_ref = max(w_vals[: step + 1])
            x_amax_ref = max(x_vals[: step + 1])
            dy_amax_ref = max(dy_vals[: step + 1])
            w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin)
            x_scale_ref = (fp8_format.value.max_fwd / x_amax_ref) / (2**margin)
            dy_scale_ref = (fp8_format.value.max_bwd / dy_amax_ref) / (2**margin)
            w_scale = model.get_quantizer("forward", 1).scale
            x_scale = model.get_quantizer("forward", 0).scale
            dy_scale = model.get_quantizer("backward", 0).scale
            torch.testing.assert_close(w_scale, torch.full_like(w_scale, w_scale_ref))
            torch.testing.assert_close(x_scale, torch.full_like(x_scale, x_scale_ref))
            torch.testing.assert_close(dy_scale, torch.full_like(dy_scale, dy_scale_ref))

    @pytest.mark.parametrize("init_dtype", _dtypes)
    @pytest.mark.parametrize("final_dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_dtype_cast(
        self,
        *,
        size: int = 32,
        init_dtype: torch.dtype,
        final_dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
    ) -> None:
        """Check dtype cast functions

        Builds a Linear op in ``init_dtype``, casts it with
        ``.float()``/``.half()``/``.bfloat16()`` to ``final_dtype``, and
        checks the weight values and the dtypes seen in forward/backward.
        """

        # Skip invalid configurations
        in_shape = (size, size)
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data, generated in the lowest-precision dtype involved
        # (bf16 < fp16 < fp32 in mantissa bits) so values survive both casts
        dtype = torch.float32
        if torch.float16 in (init_dtype, final_dtype):
            dtype = torch.float16
        if torch.bfloat16 in (init_dtype, final_dtype):
            dtype = torch.bfloat16
        w_ref, w_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )

        # Construct operation
        with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype)
        with torch.no_grad():
            op.weight.copy_(w_test)
            del w_test

        # Cast operation dtype
        if final_dtype == torch.float32:
            op.float()
        elif final_dtype == torch.float16:
            op.half()
        elif final_dtype == torch.bfloat16:
            op.bfloat16()

        # Check weights
        assert isinstance(op.weight, QuantizedTensor) == with_quantization
        assert op.weight.dtype == final_dtype
        w_test = op.weight.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(w_test, w_ref, **dtype_tols(dtype))

        # Check forward and backward pass
        x = torch.zeros(
            in_shape,
            dtype=init_dtype,
            device=device,
            requires_grad=True,
        )
        y = op(x)
        y.backward(torch.zeros_like(y))
        assert y.dtype == final_dtype
        assert x.grad.dtype == init_dtype
        assert op.weight.grad.dtype == final_dtype

    @pytest.mark.parametrize("model_dtype", _dtypes)
    @pytest.mark.parametrize("autocast_dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_pyt_autocast(
        self,
        *,
        size: int = 32,
        model_dtype: torch.dtype,
        autocast_dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weights: bool = False,
    ) -> None:
        """Test with PyTorch autocast

        Checks that the output follows the autocast dtype while input
        and weight gradients keep the model dtype, with the FP8 and
        PyTorch autocast contexts nested in both orders.
        """
        device = torch.device(device)

        # Skip invalid configurations
        in_shape = (size, size)
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Construct operation
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weights, recipe=recipe):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype)

        # Check forward and backward pass
        x = torch.zeros(
            in_shape,
            dtype=model_dtype,
            device=device,
            requires_grad=True,
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
                y = op(x)
        y.backward(torch.zeros_like(y))
        assert y.dtype == autocast_dtype
        assert x.grad.dtype == model_dtype
        assert op.weight.grad.dtype == model_dtype

        # Check forward and backward pass (swapped context order)
        if quantized_compute:
            x.grad = None
            op.weight.grad = None
            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
                with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
                    y = op(x)
            y.backward(torch.zeros_like(y))
            assert y.dtype == autocast_dtype
            assert x.grad.dtype == model_dtype
            assert op.weight.grad.dtype == model_dtype

481
482
483
484
485
486
487
488
489
490
491
492
493

class TestBasicOps:
    """Tests for individual operations"""

    @classmethod
    def setup_class(cls) -> None:
        """Seed the RNGs once for the whole test class"""
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_identity(
        self,
        *,
        in_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device,
        quantization: Optional[str],
    ) -> None:
        """Identity op must pass values through unchanged in both directions"""

        # Skip invalid configurations
        quantized = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random input and output-gradient tensors (same options for both)
        tensor_kwargs = dict(
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=quantized,
        )
        x_ref, x_test = make_reference_and_test_tensors(in_shape, **tensor_kwargs)
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            requires_grad=False,
            **tensor_kwargs,
        )

        # Run the fusible operation; the expected results are simply the
        # reference input and reference output gradient
        op = te_ops.Identity()
        y_test = op(x_test)
        y_test.backward(dy_test)

        # Identity is exact, so compare with zero tolerance
        exact = dict(rtol=0, atol=0)
        as_ref = dict(dtype=torch.float64, device="cpu")
        y_test = y_test.to(**as_ref)
        dx_test = x_test.grad.to(**as_ref)
        torch.testing.assert_close(y_test, x_ref, **exact)
        torch.testing.assert_close(dx_test, dy_ref, **exact)

        # Guard against a vacuous pass
        with pytest.raises(AssertionError):
            torch.testing.assert_close(y_test, -x_ref, **exact)
        with pytest.raises(AssertionError):
            torch.testing.assert_close(dx_test, -dy_ref, **exact)

    @pytest.mark.parametrize(
        "shapes",
        (
            ((1, 2, 3, 4), (2, 12)),
            ((5, 4, 3, 2), (-1, 6)),
            ((30,), (2, 3, -1)),
            ((6, 7), (3, -1, 7)),
        ),
    )
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
    def test_reshape(
        self,
        *,
        shapes: tuple[Iterable[int], Iterable[int]],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        memory_format: torch.memory_format = torch.contiguous_format,
        quantization: Optional[str],
    ) -> None:
        """Reshape op must match ``torch.Tensor.reshape`` in both directions"""
        in_shape, out_shape = shapes

        # Skip invalid configurations
        if memory_format == torch.channels_last and len(in_shape) != 4:
            pytest.skip("torch.channels_last only supports 4D tensors")
        maybe_skip_quantization(quantization, device=device)

        # Options shared by the input and output-gradient tensors
        tensor_kwargs = dict(
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=quantization is not None,
        )

        # Random input, laid out in the requested memory format
        x_ref, x_test = make_reference_and_test_tensors(in_shape, **tensor_kwargs)
        x_test = x_test.contiguous(memory_format=memory_format).detach().requires_grad_()

        # Random output gradient, sized like the reshaped input
        dy_ref, dy_test = make_reference_and_test_tensors(
            x_ref.reshape(out_shape).size(),
            requires_grad=False,
            **tensor_kwargs,
        )

        # Plain PyTorch reference
        y_ref = x_ref.reshape(out_shape)
        y_ref.backward(dy_ref)

        # Fusible operation
        op = te_ops.Reshape(out_shape)
        y_test = op(x_test)
        y_test.backward(dy_test)

        # Reshape is exact, so compare with zero tolerance
        exact = dict(rtol=0, atol=0)
        as_ref = dict(
            dtype=torch.float64,
            device="cpu",
            memory_format=torch.contiguous_format,
        )
        y_test = y_test.to(**as_ref)
        dx_test = x_test.grad.to(**as_ref)
        torch.testing.assert_close(y_test, y_ref, **exact)
        torch.testing.assert_close(dx_test, x_ref.grad, **exact)

    @pytest.mark.parametrize("size", (1, 7, 32))
    @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (4, 3, 8, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", _devices)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_bias(
        self,
        *,
        size: int,
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device,
        quantization: Optional[str],
    ) -> None:
        """Bias op must match a broadcasted addition in both directions"""

        # The last input dim is replaced by the bias size
        in_shape = [*in_shape][:-1] + [size]

        # Skip invalid configurations
        quantized = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data: input, bias, output gradient (in that order)
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=quantized,
        )
        b_ref, b_test = make_reference_and_test_tensors(
            size,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=quantized,
            requires_grad=False,
        )

        # Plain PyTorch reference: broadcast bias over all leading dims
        bias_view = [1] * (len(in_shape) - 1) + [size]
        y_ref = x_ref + b_ref.reshape(bias_view)
        y_ref.backward(dy_ref)

        # Fusible operation
        op = te_ops.Bias(size, device=device, dtype=dtype)
        with torch.no_grad():
            op.bias.copy_(b_test)
            del b_test
        y_test = op(x_test)
        y_test.backward(dy_test)

        # Compare output, input grad, and bias grad in float64 on CPU
        tols = dtype_tols(dtype)
        comparisons = (
            (y_test, y_ref),
            (x_test.grad, x_ref.grad),
            (op.bias.grad, b_ref.grad),
        )
        for actual, expected in comparisons:
            torch.testing.assert_close(
                actual.to(dtype=torch.float64, device="cpu"),
                expected,
                **tols,
            )

683
    @pytest.mark.parametrize("quantization", _quantization_list)
Tim Moon's avatar
Tim Moon committed
684
685
    @pytest.mark.parametrize("cast_forward", (False, True))
    @pytest.mark.parametrize("cast_backward", (False, True))
686
    def test_quantize(
687
688
        self,
        *,
689
        in_shape: Iterable[int] = (32, 32),
Tim Moon's avatar
Tim Moon committed
690
        dtype: torch.dtype = torch.bfloat16,
691
        device: torch.device = "cuda",
692
        quantization: str,
Tim Moon's avatar
Tim Moon committed
693
694
        cast_forward: bool,
        cast_backward: bool,
695
    ) -> None:
696
697
698
        """Quantize"""

        # Skip invalid configurations
699
700
701
702
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, device=device)
        if quantization == "mxfp8":
            maybe_skip_quantization(quantization, dims=in_shape)
Tim Moon's avatar
Tim Moon committed
703
704
705
706

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
707
            quantization=quantization,
Tim Moon's avatar
Tim Moon committed
708
709
            test_dtype=dtype,
            test_device=device,
710
            requires_grad=True,
Tim Moon's avatar
Tim Moon committed
711
712
713
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
714
            quantization=quantization,
Tim Moon's avatar
Tim Moon committed
715
716
717
718
719
720
721
722
723
724
725
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = x_ref
        dx_ref = dy_ref

        # Implementation with fusible operation
        op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
726
        recipe = make_recipe(quantization)
727
        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
Tim Moon's avatar
Tim Moon committed
728
729
730
731
            y_test = op(x_test)
        y_test.backward(dy_test)

        # Check tensor types
732
733
734
        if with_quantization:
            assert isinstance(y_test, QuantizedTensor) == cast_forward
            assert isinstance(x_test.grad, QuantizedTensor) == cast_backward
Tim Moon's avatar
Tim Moon committed
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749

        # Check values
        tols = dict(rtol=0, atol=0)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, dx_ref, **tols)

    def _test_basic_linear(
        self,
        *,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
750
751
752
753
754
755
756
        quantization: Optional[str] = None,
        quantized_compute: bool = False,
        quantized_input: bool = False,
        quantized_weight: bool = False,
        quantized_output: bool = False,
        quantized_grad_output: bool = False,
        quantized_grad_input: bool = False,
Tim Moon's avatar
Tim Moon committed
757
758
759
        accumulate_into_main_grad: bool = False,
    ) -> None:
        """Helper function for tests with GEMM"""
760
761
762
763
764
765
766

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
767
768
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
        quantization_needed = any(
            (
                quantized_compute,
                quantized_input,
                quantized_weight,
                quantized_output,
                quantized_grad_output,
                quantized_grad_input,
            )
        )
        if quantization is None and quantization_needed:
            pytest.skip("Quantization scheme is not specified")
        if quantization is not None and not quantization_needed:
            pytest.skip("Quantization scheme is not used")
        if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"):
            if quantized_output and not quantized_compute:
                pytest.skip("FP8 output is only supported with FP8 GEMMs")
            if quantized_grad_input and not quantized_compute:
                pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
788
789
790
791
        if quantization == "mxfp8" and quantized_output:
            pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
        if quantization == "mxfp8" and quantized_grad_input:
            pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")
792
793
794
795

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
796
            quantization=quantization,
797
798
            test_dtype=dtype,
            test_device=device,
799
            test_is_quantized=quantized_input,
800
801
802
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
803
            quantization=quantization,
804
805
806
807
808
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
809
            quantization=quantization,
810
811
            test_dtype=dtype,
            test_device=device,
812
            test_is_quantized=quantized_grad_output,
813
814
815
816
817
818
819
820
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
821
822
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
823
824
825
826
827
828
829
830
831
832
833
            op = te_ops.BasicLinear(
                in_features,
                out_features,
                device=device,
                dtype=dtype,
                accumulate_into_main_grad=accumulate_into_main_grad,
            )
        with torch.no_grad():
            op.weight.copy_(w_test)
            del w_test
            op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
Tim Moon's avatar
Tim Moon committed
834
        forward = te_ops.Sequential(
835
            te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
Tim Moon's avatar
Tim Moon committed
836
            op,
837
            te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
Tim Moon's avatar
Tim Moon committed
838
        )
839
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
Tim Moon's avatar
Tim Moon committed
840
            y_test = forward(x_test)
841
842
843
844
845
846
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
847
848
        if quantized_compute or quantized_output or quantized_grad_input:
            tols = dtype_tols(tex.DType.kFloat8E4M3)
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        if accumulate_into_main_grad:
            if op.weight.grad is not None:
                torch.testing.assert_close(
                    op.weight.grad,
                    torch.zeros_like(op.weight.grad),
                    rtol=0,
                    atol=0,
                )
            dw_test = op.weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5
        else:
            dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(
                op.weight.main_grad,
                torch.full_like(op.weight.main_grad, 0.5),
                rtol=0,
                atol=0,
            )
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)

874
875
    @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5)))
    @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
    def test_basic_linear(
        self,
        *,
        weight_shape: tuple[int, int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        quantization: Optional[str],
        accumulate_into_main_grad: bool,
    ) -> None:
        """GEMM

        Thin driver around ``_test_basic_linear``. Only ``quantized_compute``
        is derived from ``quantization`` (enabled whenever a scheme is given);
        the remaining quantization flags of ``_test_basic_linear`` are left at
        their defaults.

        Parameters
        ----------
        weight_shape: weight shape, forwarded to ``_test_basic_linear``
        in_shape: input shape, forwarded to ``_test_basic_linear``
        dtype: dtype of the test tensors
        quantization: quantization scheme name, or ``None`` for none
        accumulate_into_main_grad: forwarded to ``_test_basic_linear``
        """
        self._test_basic_linear(
            weight_shape=weight_shape,
            in_shape=in_shape,
            dtype=dtype,
            quantization=quantization,
            quantized_compute=quantization is not None,
            accumulate_into_main_grad=accumulate_into_main_grad,
        )

    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_compute", (False, True))
    @pytest.mark.parametrize("quantized_input", (False, True))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    @pytest.mark.parametrize("quantized_output", (False, True))
    @pytest.mark.parametrize("quantized_grad_output", (False, True))
    @pytest.mark.parametrize("quantized_grad_input", (False, True))
    def test_basic_linear_quantized(
        self,
        *,
        quantization: Optional[str],
        quantized_compute: bool,
        quantized_input: bool,
        quantized_weight: bool,
        quantized_output: bool,
        quantized_grad_output: bool,
        quantized_grad_input: bool,
    ) -> None:
        """GEMM with FP8 inputs and outputs

        Sweeps every combination of the quantization flags accepted by
        ``_test_basic_linear`` with BF16 tensors. ``_quantization_list`` may
        contain ``None``; that case is skipped here because it carries no
        quantization scheme to exercise.
        """
        if quantization is None:
            pytest.skip("Skipping case without quantization")
        self._test_basic_linear(
            dtype=torch.bfloat16,
            quantization=quantization,
            quantized_compute=quantized_compute,
            quantized_input=quantized_input,
            quantized_weight=quantized_weight,
            quantized_output=quantized_output,
            quantized_grad_output=quantized_grad_output,
            quantized_grad_input=quantized_grad_input,
        )

    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_compute", (False, True))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    @pytest.mark.parametrize("input_requires_grad", (False, True))
    @pytest.mark.parametrize("weight_requires_grad", (False, True))
    def test_linear(
        self,
        *,
        bias: bool,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_compute: bool,
        quantized_weight: bool,
        input_requires_grad: bool,
        weight_requires_grad: bool,
    ) -> None:
        """GEMM + bias

        Compares ``te_ops.Linear`` against ``torch.nn.functional.linear``.

        ``input_requires_grad`` controls whether the input participates in
        autograd and ``weight_requires_grad`` whether the op's parameters do.
        The test-side backward pass is skipped when neither requires grad,
        and only the gradients that were requested are checked.
        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantization is None and (quantized_compute or quantized_weight):
            pytest.skip("Quantization scheme is not specified")
        if quantization is not None and not (quantized_compute or quantized_weight):
            pytest.skip("Quantization scheme is not used")

        # Random data
        # The input's requires_grad follows input_requires_grad so that the
        # "no input gradient" path is actually exercised. w_ref keeps the
        # helper's default requires_grad, which keeps y_ref.backward() valid.
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=input_requires_grad,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref, b_test = None, None
        if bias:
            b_ref, b_test = make_reference_and_test_tensors(
                out_features,
                test_dtype=dtype,
                test_device=device,
            )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            op = te_ops.Linear(
                in_features,
                out_features,
                bias=bias,
                device=device,
                dtype=dtype,
            )
        with torch.no_grad():
            op.weight.copy_(w_test)
            if bias:
                op.bias.copy_(b_test)
            del w_test
            del b_test
            for param in op.parameters():
                param.requires_grad_(requires_grad=weight_requires_grad)
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = op(x_test)
        # Nothing in the test graph requires grad otherwise, so backward
        # would raise.
        if input_requires_grad or weight_requires_grad:
            y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        if input_requires_grad:
            dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        else:
            # No gradient may flow into an input that does not require one
            assert x_test.grad is None
        if weight_requires_grad:
            dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(dw_test, w_ref.grad, **tols)
            if bias:
                db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
                torch.testing.assert_close(db_test, b_ref.grad, **tols)

    @pytest.mark.parametrize("weight_shape", ((7, 2), (32,)))
    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("zero_centered_gamma", (False, True))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_layer_norm(
        self,
        *,
        weight_shape: Iterable[int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        eps: float = 0.3,
        zero_centered_gamma: bool,
        quantization: Optional[str],
    ) -> None:
        """Layer norm

        Compares ``te_ops.LayerNorm`` against
        ``torch.nn.functional.layer_norm``. With ``zero_centered_gamma`` the
        reference applies ``weight + 1`` as the scale. When a quantization
        scheme is given, the norm is followed by a forward-only
        ``te_ops.Quantize`` op and FP8 tolerances are used.
        """

        # Make input and weight shapes consistent
        in_shape = list(in_shape)[:-1] + list(weight_shape)

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            weight_shape,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref, b_test = make_reference_and_test_tensors(
            weight_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.layer_norm(
            x_ref,
            weight_shape,
            weight=(w_ref + 1 if zero_centered_gamma else w_ref),
            bias=b_ref,
            eps=eps,
        )
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.LayerNorm(
            weight_shape,
            eps=eps,
            device=device,
            dtype=dtype,
            zero_centered_gamma=zero_centered_gamma,
        )
        with torch.no_grad():
            op.weight.copy_(w_test)
            op.bias.copy_(b_test)
            del w_test
            del b_test
        quantized_compute = quantization is not None
        recipe = make_recipe(quantization)
        forward = te_ops.Sequential(
            op,
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
        db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
        torch.testing.assert_close(db_test, b_ref.grad, **tols)

    def test_layer_norm_autocast(
        self,
        *,
        weight_shape: Iterable[int] = (32,),
        in_shape: Iterable[int] = (32,),
        dtype: torch.dtype = torch.float16,
        autocast_dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        eps: float = 0.3,
    ) -> None:
        """Layer norm with PyTorch autocast

        The layer's weight and bias are created in ``dtype``, while the input
        and the incoming gradient use ``autocast_dtype``. The forward pass runs
        inside ``torch.autocast`` and the output must come back in
        ``autocast_dtype``. Activation-side results are checked with
        ``autocast_dtype`` tolerances and parameter gradients with ``dtype``
        tolerances.
        """
        full_shape = list(in_shape)[:-1] + list(weight_shape)

        # Reference/test tensor pairs, created in the same order as always so
        # the random stream is consumed identically.
        ref_x, test_x = make_reference_and_test_tensors(
            full_shape, test_dtype=autocast_dtype, test_device=device
        )
        ref_gamma, test_gamma = make_reference_and_test_tensors(
            weight_shape, test_dtype=dtype, test_device=device
        )
        ref_beta, test_beta = make_reference_and_test_tensors(
            weight_shape, test_dtype=dtype, test_device=device
        )
        ref_dy, test_dy = make_reference_and_test_tensors(
            full_shape,
            test_dtype=autocast_dtype,
            test_device=device,
            requires_grad=False,
        )

        # Reference forward/backward in plain PyTorch
        ref_out = torch.nn.functional.layer_norm(
            ref_x, weight_shape, weight=ref_gamma, bias=ref_beta, eps=eps
        )
        ref_out.backward(ref_dy)

        # Fusible-op forward/backward under autocast
        layer = te_ops.LayerNorm(weight_shape, eps=eps, device=device, dtype=dtype)
        with torch.no_grad():
            layer.weight.copy_(test_gamma)
            layer.bias.copy_(test_beta)
            del test_gamma, test_beta
        with torch.autocast(device, dtype=autocast_dtype):
            test_out = layer(test_x)
        test_out.backward(test_dy)

        # The output dtype must follow the autocast context
        assert test_out.dtype == autocast_dtype

        def _as_cpu_f64(t: torch.Tensor) -> torch.Tensor:
            return t.to(dtype=torch.float64, device="cpu")

        out_f64 = _as_cpu_f64(test_out)
        dx_f64 = _as_cpu_f64(test_x.grad)
        dgamma_f64 = _as_cpu_f64(layer.weight.grad)
        dbeta_f64 = _as_cpu_f64(layer.bias.grad)
        torch.testing.assert_close(out_f64, ref_out, **dtype_tols(autocast_dtype))
        torch.testing.assert_close(dx_f64, ref_x.grad, **dtype_tols(autocast_dtype))
        torch.testing.assert_close(dgamma_f64, ref_gamma.grad, **dtype_tols(dtype))
        torch.testing.assert_close(dbeta_f64, ref_beta.grad, **dtype_tols(dtype))

1211
1212
    @pytest.mark.parametrize("weight_shape", ((19,), (64,)))
    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("zero_centered_gamma", (False, True))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_rmsnorm(
        self,
        *,
        weight_shape: Iterable[int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        eps: float = 0.3,
        zero_centered_gamma: bool,
        quantization: Optional[str],
    ) -> None:
        """RMSNorm

        Compares ``te_ops.RMSNorm`` against a hand-written reference
        ``y = x / sqrt(eps + mean(x**2)) * gamma``, where the mean is taken
        over the trailing ``weight_shape`` dims and ``gamma`` is ``1 + w``
        when ``zero_centered_gamma`` is set. When a quantization scheme is
        given, the norm is followed by a forward-only ``te_ops.Quantize`` op
        and FP8 tolerances are used.
        """

        # Make input and weight shapes consistent
        in_shape = list(in_shape)[:-1] + list(weight_shape)

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            weight_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        # var_ref is the mean of squares over the normalized (trailing) dims;
        # unlike layer norm, RMSNorm does not subtract the mean first.
        inner_dims = tuple(range(len(in_shape) - len(weight_shape), len(in_shape)))
        var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape)
        if zero_centered_gamma:
            y_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref)
        else:
            y_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.RMSNorm(
            weight_shape,
            eps=eps,
            device=device,
            dtype=dtype,
            zero_centered_gamma=zero_centered_gamma,
        )
        with torch.no_grad():
            op.weight.copy_(w_test)
            del w_test
        quantized_compute = quantization is not None
        recipe = make_recipe(quantization)
        forward = te_ops.Sequential(
            op,
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)

    @pytest.mark.parametrize("in_shape", ((32,), (6, 16, 64), (32, 64)))
    @pytest.mark.parametrize("dtype", _dtypes)
    def test_l2normalization(
        self,
        *,
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        eps: float = 1e-6,
    ) -> None:
        """L2 Normalization

        Compares ``te_ops.L2Normalization`` against the reference
        ``y = x * rsqrt(sum(x**2) + eps)``, reducing over the last dim.
        """

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        # y = x / sqrt(sum(x^2) + eps), i.e. x / ||x||_2 stabilized by eps
        l2_norm_squared = x_ref.pow(2).sum(dim=-1, keepdim=True)
        rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
        y_ref = x_ref * rsqrt_norm
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.L2Normalization(
            eps=eps,
        )
        y_test = op(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")

        torch.testing.assert_close(y_test, y_ref, **tols)
        # L2Norm backward pass requires slightly looser atol for bfloat16.
        # Build a new dict instead of mutating the one returned by dtype_tols.
        dx_tols = {**tols, "atol": 2e-3} if dtype == torch.bfloat16 else tols
        torch.testing.assert_close(dx_test, x_ref.grad, **dx_tols)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_add_in_place(
        self,
        *,
        in_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device,
        quantization: Optional[str],
    ) -> None:
        """Add two tensors

        Join in compute graph.

        ``te_ops.AddInPlace`` is checked against the reference ``y = x2 + x1``.
        The gradient with respect to each input must equal ``dy`` exactly.
        When a quantization scheme is given, inputs and grads are quantized
        test tensors and the forward tolerance follows their FP8 dtype.
        """

        # Skip invalid configurations
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x1_ref, x1_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
        )
        x2_ref, x2_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        # Out-of-place: detach() aliases storage, so "+=" on it would silently
        # overwrite x2_ref's data.
        y_ref = x2_ref.detach() + x1_ref.detach()
        dx1_ref = dy_ref
        dx2_ref = dy_ref

        # Implementation with fusible operation
        op = te_ops.AddInPlace()
        y_test = op(x1_test, x2_test)
        y_test.backward(dy_test)

        # Check results
        tols = dtype_tols(dtype)
        if with_quantization:
            tols = dtype_tols(x1_test._fp8_dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
        dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0)
        torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_make_extra_output(
        self,
        *,
        in_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device,
        quantization: Optional[str],
    ) -> None:
        """Output tensor twice

        Split in compute graph.

        ``te_ops.MakeExtraOutput`` must return two outputs that both equal the
        input exactly. The input gradient must match the sum of the two
        branches' contributions from the reference graph.
        """

        # Skip invalid configurations
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
        )
        dy1_ref, dy1_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
            requires_grad=False,
        )
        dy2_ref, dy2_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y1_ref = x_ref
        y2_ref = x_ref
        (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()

        # Implementation with fusible operation
        op = te_ops.MakeExtraOutput()
        y1_test, y2_test = op(x_test)
        (y1_test * dy1_test + y2_test * dy2_test).sum().backward()

        # Check results
        tols = dtype_tols(dtype)
        y1_test = y1_test.to(dtype=torch.float64, device="cpu")
        y2_test = y2_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y1_test, y1_ref, rtol=0, atol=0)
        torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
    @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("cache_quantized_input", (False, True))
    def test_activation(
        self,
        *,
        activation: str,
        out_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        cache_quantized_input: bool,
    ) -> None:
        """Activation functions

        Compares TE activation ops against PyTorch references. The gated
        variants (geglu, reglu, swiglu) take an input whose last dim is twice
        the output's and multiply the activated first half by the second half.

        When a quantization scheme is given, the activation is wrapped between
        ``te_ops.Quantize`` ops. When ``cache_quantized_input`` is set, the
        input data is generated with the ``"fp8_current_scaling"`` scheme.
        Either option switches to FP8 tolerances.
        """

        # Tensor dimensions
        in_shape = list(out_shape)
        if activation in ("geglu", "reglu", "swiglu"):
            in_shape[-1] *= 2

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        if cache_quantized_input:
            maybe_skip_quantization("fp8_current_scaling", device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization="fp8_current_scaling" if cache_quantized_input else None,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref: torch.Tensor
        if activation == "gelu":
            y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
        elif activation == "relu":
            y_ref = torch.nn.functional.relu(x_ref)
        elif activation == "geglu":
            x1, x2 = x_ref.chunk(2, dim=-1)
            y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2
        elif activation == "reglu":
            x1, x2 = x_ref.chunk(2, dim=-1)
            y_ref = torch.nn.functional.relu(x1) * x2
        elif activation == "swiglu":
            x1, x2 = x_ref.chunk(2, dim=-1)
            y_ref = torch.nn.functional.silu(x1) * x2
        else:
            raise ValueError(f"Unexpected activation function ({activation})")
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        recipe = make_recipe(quantization)
        make_op = dict(
            gelu=te_ops.GELU,
            relu=te_ops.ReLU,
            geglu=te_ops.GEGLU,
            reglu=te_ops.ReGLU,
            swiglu=te_ops.SwiGLU,
        )[activation]
        forward = te_ops.Sequential(
            te_ops.Quantize(forward=False, backward=quantized_compute),
            make_op(cache_quantized_input=cache_quantized_input),
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute or cache_quantized_input:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantize_forward", (False, True))
    @pytest.mark.parametrize("quantize_backward", (False, True))
    def test_swiglu(
        self,
        *,
        out_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantize_forward: bool,
        quantize_backward: bool,
    ) -> None:
        """SwiGLU with independently toggled forward/backward quantization

        The input's last dim is twice the output's. The reference is
        ``silu(x1) * x2``, where ``x1`` and ``x2`` are the two halves of the
        input. ``quantize_forward`` and ``quantize_backward`` enable the
        ``te_ops.Quantize`` ops placed after and before ``te_ops.SwiGLU``
        respectively; they require a quantization scheme.
        """

        # Tensor dimensions
        in_shape = list(out_shape)
        in_shape[-1] *= 2

        # Skip invalid configurations
        quantized_compute = quantization is not None
        if not quantized_compute and (quantize_forward or quantize_backward):
            pytest.skip("Quantization scheme has not been provided")
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        x1, x2 = x_ref.chunk(2, dim=-1)
        y_ref = torch.nn.functional.silu(x1) * x2
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        recipe = make_recipe(quantization)
        forward = te_ops.Sequential(
            te_ops.Quantize(forward=False, backward=quantize_backward),
            te_ops.SwiGLU(),
            te_ops.Quantize(forward=quantize_forward, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)


class TestFusedOps:
    """Tests for fused operations"""

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA RNGs once for every test in this class."""
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
    @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_weight", (False, True))
    def test_forward_linear_bias_activation(
        self,
        *,
        bias: bool = True,
        weight_shape: tuple[int, int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool,
    ) -> None:
        """Forward GEMM + bias + activation

        Builds a single ``te_ops.Linear`` inside a ``te_ops.Sequential`` and
        checks that its forward pass is fused into
        ``ForwardLinearBiasActivation``. Outputs and gradients are compared
        against ``torch.nn.functional.linear``. No activation op is added to
        the model, so only the GEMM + bias part of the fused op is exercised.

        ``quantized_weight`` controls whether the weights are created as
        quantized tensors (``fp8_model_init``). ``quantization`` controls the
        recipe used for compute (``fp8_autocast``).
        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantization is None and quantized_weight:
            # Without a recipe, fp8_model_init would silently fall back to
            # its default recipe, which maybe_skip_quantization(None) does
            # not guard against.
            pytest.skip("Quantization scheme is not specified")
        if dtype not in (torch.float16, torch.bfloat16):
            # The fusion assertion below cannot pass for other dtypes, even
            # without quantization.
            pytest.skip(
                "FP8 fused linear-bias-activation is only supported with FP16 or BF16 output"
            )

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref, b_test = None, None
        if bias:
            b_ref, b_test = make_reference_and_test_tensors(
                out_features,
                test_dtype=dtype,
                test_device=device,
            )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operations. Weight quantization is
        # controlled by quantized_weight (previously this was tied to
        # quantized_compute, which left quantized_weight unused).
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                ),
            )
        with torch.no_grad():
            model[0].weight.copy_(w_test)
            if bias:
                model[0].bias.copy_(b_test)
            del w_test
            del b_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)

        # Check that forward operations have been fused
        forward_ops = model._module_groups[0]._forward_ops
        assert len(forward_ops) == 1
        assert isinstance(forward_ops[0][0], ForwardLinearBiasActivation)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
        if bias:
            db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(db_test, b_ref.grad, **tols)

    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_forward_linear_bias_add(
        self,
        *,
        bias: bool,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool = False,
    ) -> None:
        """Forward GEMM + bias + add

        Checks that ``Linear`` followed by ``AddInPlace`` is fused into
        ``ForwardLinearBiasAdd``. Outputs and gradients of both inputs, the
        weight and (optionally) the bias are compared against plain PyTorch.
        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")

        # Random data
        x1_ref, x1_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref, b_test = None, None
        if bias:
            b_ref, b_test = make_reference_and_test_tensors(
                out_features,
                test_dtype=dtype,
                test_device=device,
            )
        x2_ref, x2_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x1_ref, w_ref, bias=b_ref) + x2_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                ),
                te_ops.AddInPlace(),
            )
        with torch.no_grad():
            model[0].weight.copy_(w_test)
            if bias:
                model[0].bias.copy_(b_test)
            del w_test
            del b_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = model(x1_test, x2_test)
        y_test.backward(dy_test)

        # Check that forward operations have been fused
        forward_ops = model._module_groups[0]._forward_ops
        assert len(forward_ops) == 1
        assert isinstance(forward_ops[0][0], ForwardLinearBiasAdd)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
        dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx1_test, x1_ref.grad, **tols)
        torch.testing.assert_close(dx2_test, x2_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
        if bias:
            db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(db_test, b_ref.grad, **tols)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_backward_linear_add(
        self,
        *,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool = False,
    ) -> None:
        """Backward dgrad GEMM + add

        ``MakeExtraOutput`` returns the input as a second output alongside
        the ``Linear`` output. The test checks that the backward pass is
        fused into ``BackwardLinearAdd``, and that outputs and gradients
        match plain PyTorch.
        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        dy1_ref, dy1_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        dy2_ref, dy2_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y1_ref = torch.nn.functional.linear(x_ref, w_ref)
        y2_ref = x_ref
        (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()

        # Implementation with fusible operations. Pass the same recipe used
        # for compute so quantized weights (if any) match the autocast recipe
        # instead of fp8_model_init's default.
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.MakeExtraOutput(),
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=False,
                    device=device,
                    dtype=dtype,
                ),
            )
        with torch.no_grad():
            model[1].weight.copy_(w_test)
            del w_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y1_test, y2_test = model(x_test)
        (y1_test * dy1_test + y2_test * dy2_test).sum().backward()

        # Check that backward operations have been fused
        backward_ops = model._module_groups[0]._backward_ops
        assert len(backward_ops) == 1
        assert isinstance(backward_ops[0][0], BackwardLinearAdd)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y1_test = y1_test.to(dtype=torch.float64, device="cpu")
        y2_test = y2_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y1_test, y1_ref, **tols)
        torch.testing.assert_close(y2_test, y2_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)


class TestCheckpointing:
    """Tests for checkpointing"""

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA RNGs once for every test in this class."""
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_weight", (False, True))
    def test_linear(
        self,
        *,
        pre_checkpoint_steps: int = 2,
        post_checkpoint_steps: int = 2,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool,
    ) -> None:
        """Check checkpointing with linear op

        1. Train a model for ``pre_checkpoint_steps`` steps.
        2. Serialize model and optimizer state to an in-memory buffer.
        3. Train the original model for ``post_checkpoint_steps`` more steps.
        4. Load the checkpoint into a freshly built model and optimizer, then
           replay the same steps on cloned inputs.

        Parameters, parameter gradients, outputs and input gradients of both
        runs must match exactly (zero tolerance).
        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantization is None and quantized_weight:
            # Without a recipe, fp8_model_init would fall back to its default
            # recipe, and maybe_skip_quantization(None) performs no
            # availability check for it.
            pytest.skip("Quantization scheme is not specified")

        # Construct model
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model_save = te_ops.Sequential(
                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
            )
        optim_save = torch.optim.SGD(model_save.parameters(), lr=0.25)

        # Warmup training steps
        for _ in range(pre_checkpoint_steps):
            x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
            dy = torch.randn(out_shape, dtype=dtype, device=device)
            optim_save.zero_grad()
            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
                y = model_save(x)
            y.backward(dy)
            optim_save.step()

        # Save checkpoint
        byte_stream = io.BytesIO()
        torch.save(
            {"model": model_save.state_dict(), "optim": optim_save.state_dict()},
            byte_stream,
        )
        checkpoint_bytes = byte_stream.getvalue()
        del byte_stream

        # Synthetic data for evaluation. The loaded model gets detached clones
        # of the same inputs so input gradients can be compared per run.
        xs_save = [
            torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
            for _ in range(post_checkpoint_steps)
        ]
        with torch.no_grad():
            xs_load = [x.clone().requires_grad_() for x in xs_save]
        dys = [
            torch.randn(out_shape, dtype=dtype, device=device) for _ in range(post_checkpoint_steps)
        ]

        # Training steps with original model
        ys_save = []
        for i in range(post_checkpoint_steps):
            optim_save.zero_grad()
            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
                y = model_save(xs_save[i])
            y.backward(dys[i])
            optim_save.step()
            ys_save.append(y)

        # Load checkpoint into a freshly constructed model and optimizer
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model_load = te_ops.Sequential(
                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
            )
        optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25)
        # The bytes were produced by torch.save above in this same process, so
        # full unpickling (weights_only=False) is trusted here. Presumably it
        # is needed for non-tensor objects in the state dict; TODO confirm.
        state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False)
        model_load.load_state_dict(state_dict["model"])
        optim_load.load_state_dict(state_dict["optim"])

        # Training steps with loaded model
        ys_load = []
        for i in range(post_checkpoint_steps):
            optim_load.zero_grad()
            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
                y = model_load(xs_load[i])
            y.backward(dys[i])
            optim_load.step()
            ys_load.append(y)

        # Check that original and loaded model match exactly
        tols = {"rtol": 0, "atol": 0}
        for param_load, param_save in zip(model_load.parameters(), model_save.parameters()):
            torch.testing.assert_close(param_load, param_save, **tols)
            torch.testing.assert_close(param_load.grad, param_save.grad, **tols)
        for y_load, y_save in zip(ys_load, ys_save):
            torch.testing.assert_close(y_load, y_save, **tols)
        for x_load, x_save in zip(xs_load, xs_save):
            torch.testing.assert_close(x_load.grad, x_save.grad, **tols)


class TestSequentialModules:
    """Test for larger Sequentials with modules commonly used together"""

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA RNGs once for every test in this class."""
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
    @pytest.mark.parametrize("quantized_compute", (False, True))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_layernorm_mlp(
        self,
        *,
        bias: bool,
        normalization: str,
        quantized_compute: bool,
        quantized_weight: bool,
        dtype: torch.dtype,
        quantization: Optional[str],
        device: torch.device = "cuda",
        hidden_size: int = 32,
        sequence_length: int = 512,
        batch_size: int = 4,
        ffn_hidden_size: int = 64,
        layernorm_epsilon: float = 1e-5,
    ) -> None:
        """
        LayerNorm/RMSNorm + Linear + GELU + Linear

        Note that this test checks only if the module runs
        as when chaining multiple modules it is hard to validate
        numerical accuracy.

        A forward and backward pass is run with the requested normalization,
        optional bias, and independent toggles for quantized weights
        (``fp8_model_init``) and quantized compute (``fp8_autocast``).
        """

        # Make input shape; the MLP maps hidden_size -> ffn_hidden_size -> hidden_size
        in_shape = (sequence_length, batch_size, hidden_size)
        ffn_shape = in_shape[:-1] + (ffn_hidden_size,)

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=ffn_shape, device=device)
        quantization_needed = quantized_compute or quantized_weight
        if quantization is None and quantization_needed:
            pytest.skip("Quantization scheme is not specified")
        if quantization is not None and not quantization_needed:
            # Avoid re-running the unquantized configuration once per scheme
            pytest.skip("Quantization scheme is not used")

        # Random data (reference tensors are unused: no numerical check)
        _, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        # The final Linear maps back to hidden_size, so dy has the input shape
        _, dy_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            if normalization == "LayerNorm":
                norm = te_ops.LayerNorm(
                    hidden_size,
                    eps=layernorm_epsilon,
                    device=device,
                    dtype=dtype,
                )
            else:
                norm = te_ops.RMSNorm(
                    hidden_size,
                    eps=layernorm_epsilon,
                    device=device,
                    dtype=dtype,
                )
            ffn1 = te_ops.Linear(
                hidden_size,
                ffn_hidden_size,
                bias=bias,
                device=device,
                dtype=dtype,
            )
            act = te_ops.GELU()
            ffn2 = te_ops.Linear(
                ffn_hidden_size,
                hidden_size,
                bias=bias,
                device=device,
                dtype=dtype,
            )
        forward = te_ops.Sequential(norm, ffn1, act, ffn2)
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)