# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from __future__ import annotations

7
from collections.abc import Iterable
8
import io
9
import math
10
11
import pathlib
import sys
12
from typing import Optional
13
14
15

import pytest
import torch
yuguo's avatar
yuguo committed
16
from torch.utils.cpp_extension import IS_HIP_EXTENSION
17
18

import transformer_engine
Tim Moon's avatar
Tim Moon committed
19
import transformer_engine.common.recipe
20
21
22
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
23
from transformer_engine.pytorch.ops.fused import (
Jan Bielak's avatar
Jan Bielak committed
24
    BackwardActivationBias,
25
    BackwardLinearAdd,
Jan Bielak's avatar
Jan Bielak committed
26
    BackwardLinearScale,
27
    ForwardLinearBiasActivation,
28
    ForwardLinearBiasAdd,
Jan Bielak's avatar
Jan Bielak committed
29
    ForwardLinearScaleAdd,
30
)
31
from transformer_engine.pytorch.tensor import QuantizedTensor
32
33
34
35
36
37
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Tensor,
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
38
39
40
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex

41
42
43
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent))
44
from utils import dtype_tols, make_recipe, reset_rng_states
45

if IS_HIP_EXTENSION:
    import os
    from functools import cache

    @cache
    def use_hipblaslt() -> bool:
        """Report whether the hipBLASLt path is selected via environment variables

        Returns True when ``NVTE_USE_HIPBLASLT`` is set, or when
        ``NVTE_USE_ROCBLAS`` is not set. The result is cached on the first
        call, so later changes to the environment are not observed.
        """
        if "NVTE_USE_HIPBLASLT" in os.environ:
            return True
        return "NVTE_USE_ROCBLAS" not in os.environ
# Probe FP8 / MXFP8 support once; the reason strings are used as pytest
# skip messages when a scheme is unavailable
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Data types exercised by parametrized tests (bf16 requires sm_80 or higher)
_dtypes: list[torch.dtype] = [torch.float32, torch.float16] + (
    [torch.bfloat16] if is_bf16_compatible() else []
)

# Devices exercised by parametrized tests
_devices: list[torch.device] = [torch.device(name) for name in ("cpu", "cuda")]

# Quantization recipes exercised by parametrized tests (None = unquantized)
_quantization_list: list[Optional[str]] = [
    None,
    *(("fp8_delayed_scaling", "fp8_current_scaling") if fp8_available else ()),
    *(("mxfp8",) if mxfp8_available else ()),
]
def maybe_skip_quantization(
    quantization: Optional[str],
    *,
    dims: Optional[Iterable[int] | int] = None,
    device: Optional[torch.device | str] = None,
) -> None:
    """Skip test case if a quantization scheme is not supported

    Args:
        quantization: Quantization scheme name, or ``None`` for no
            quantization (in which case nothing is ever skipped).
        dims: Tensor dimensions used with the scheme. An int is treated
            as a 1D shape; any iterable of ints is accepted. For FP8
            schemes, both the product of all but the last dim and the
            last dim must be divisible by 16; for MXFP8, by 32.
        device: Device the test runs on. Quantized tests are skipped on
            anything other than CUDA.

    Calls ``pytest.skip`` (which raises) for unsupported configurations.
    """

    # Don't skip if there is no quantization
    if quantization is None:
        return

    # Check if quantization scheme is supported
    if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Check if dims are compatible with quantized GEMMs
    if dims is not None:
        # Materialize as a tuple: the checks below slice and index, which
        # generic iterables (generators, sets, ...) do not support
        dims = tuple(dims) if isinstance(dims, Iterable) else (dims,)
        if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"):
            if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
                pytest.skip("FP8 GEMMs require dims that are divisible by 16")
        elif quantization == "mxfp8":
            if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0:
                pytest.skip("MXFP8 GEMMs require dims that are divisible by 32")

    # Check if device is supported
    if device is not None and torch.device(device).type != "cuda":
        pytest.skip("Quantization is only supported on CUDA devices")


107
108
109
@torch.no_grad()
def make_reference_and_test_tensors(
    shape: int | Iterable[int],
    quantization: Optional[str] = None,
    ref_dtype: torch.dtype = torch.float64,
    ref_device: torch.device = "cpu",
    test_dtype: torch.dtype = torch.float32,
    test_device: torch.device = "cuda",
    test_is_quantized: bool = False,
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build a (reference, test) pair of tensors holding identical values

    The reference tensor is meant for high-precision plain PyTorch
    operations; the test tensor is meant for Transformer Engine
    operations. Values are sampled with ``torch.rand``.

    When a quantization scheme is given, the test tensor is quantized and
    the reference is then overwritten with the quantized values, so both
    hold exactly representable values. The test tensor stays a
    ``QuantizedTensor`` only if ``test_is_quantized`` is set; otherwise it
    is dequantized back to a plain tensor.

    Raises:
        ValueError: if ``test_is_quantized`` is set without a quantization
            scheme, or if the quantization scheme is not recognized.
    """

    # Sample reference values and derive the test tensor from them
    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
    test = ref.to(device=test_device, dtype=test_dtype)

    if quantization is None:
        if test_is_quantized:
            raise ValueError("Quantization scheme not provided")
        # Tensor.to may hand back ``ref`` itself when no conversion is
        # needed; the pair must not alias each other
        if test.data_ptr() == ref.data_ptr():
            test = test.clone()
    else:
        # Pick the quantizer for the requested scheme, then apply it once
        if quantization in ("fp8", "fp8_delayed_scaling"):
            quantizer = Float8Quantizer(
                scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
                amax=torch.zeros(1, dtype=torch.float32, device=test_device),
                fp8_dtype=tex.DType.kFloat8E4M3,
            )
        elif quantization == "fp8_current_scaling":
            quantizer = Float8CurrentScalingQuantizer(
                fp8_dtype=tex.DType.kFloat8E4M3,
                device=test_device,
            )
        elif quantization == "mxfp8":
            quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
        else:
            raise ValueError(f"Unsupported quantization scheme ({quantization})")
        test = quantizer(test)

    if isinstance(test, QuantizedTensor) and not test_is_quantized:
        test = test.dequantize()

    # Reference takes on the (possibly quantized) test values exactly
    ref.copy_(test)

    ref.requires_grad_(requires_grad)
    test.requires_grad_(requires_grad)
    return ref, test


167
class TestSequentialContainer:
    """Tests for sequential container"""

    @staticmethod
    def _check_modules(model, modules: list[torch.nn.Module]) -> None:
        """Assert that ``model`` holds exactly ``modules``, in the same order

        Lengths are compared explicitly because ``zip`` stops at the shorter
        iterable and would otherwise hide missing or extra modules.
        """
        assert len(model) == len(modules)
        for module1, module2 in zip(model, modules):
            assert module1 is module2

    def test_modules(self) -> None:
        """Check that list of modules can be manipulated as expected"""

        # Construct sequential container
        modules = [
            te_ops.Identity(),
            te_ops.Identity(),
            torch.nn.Identity(),
            te_ops.Identity(),
        ]
        model = te_ops.Sequential(*modules)

        # Length and iterator
        self._check_modules(model, modules)

        # Index by int (positive and negative indices)
        for i, module in enumerate(modules):
            assert model[i] is module
            assert model[i - len(modules)] is module

        # Index by slice
        model_subset = model[1:-1]
        modules_subset = modules[1:-1]
        assert isinstance(model_subset, te_ops.Sequential)
        self._check_modules(model_subset, modules_subset)

        # Set element
        new_module = torch.nn.Identity()
        idx = 1
        modules[idx] = new_module
        model[idx] = new_module
        self._check_modules(model, modules)

        # Delete element
        idx = 1
        del modules[idx]
        del model[idx]
        self._check_modules(model, modules)

        # Append
        new_module = torch.nn.Identity()
        modules.append(new_module)
        model.append(new_module)
        self._check_modules(model, modules)

        # Extend
        new_modules = [te_ops.Identity(), te_ops.Identity()]
        modules.extend(new_modules)
        model.extend(new_modules)
        self._check_modules(model, modules)

        # Insert
        new_module = te_ops.Identity()
        idx = 2
        modules.insert(idx, new_module)
        model.insert(idx, new_module)
        self._check_modules(model, modules)

        # Pop
        idx = 2
        assert model.pop(idx) is modules.pop(idx)
        self._check_modules(model, modules)

        # Out-of-place add (original container must be left untouched)
        new_modules = [torch.nn.Identity(), te_ops.Identity()]
        added_modules = modules + new_modules
        added_model = model + te_ops.Sequential(*new_modules)
        self._check_modules(model, modules)
        self._check_modules(added_model, added_modules)

        # In-place add
        new_modules = [te_ops.Identity(), torch.nn.Identity()]
        modules += new_modules
        model += te_ops.Sequential(*new_modules)
        self._check_modules(model, modules)

    def test_module_groups(self) -> None:
        """Check that modules are grouped together correctly"""
        model = te_ops.Sequential(
            te_ops.Identity(),
            te_ops.Identity(),
            torch.nn.Identity(),
            torch.nn.Identity(),
            te_ops.Identity(),
            torch.nn.Identity(),
            te_ops.Identity(),
            te_ops.Identity(),
            te_ops.Identity(),
        )
        # Groups are inspected after a forward pass. NOTE(review): 6 groups is
        # consistent with consecutive te_ops sharing a group and each
        # torch.nn module forming its own group; confirm against Sequential.
        model(torch.zeros(1))
        assert len(model._module_groups) == 6

    def test_extra_tensors(self, size: int = 16) -> None:
        """Check that extra inputs are distributed properly between module groups
        and that extra outputs are properly collected"""

        # Construct sequential container
        bias = te_ops.Bias(size=size, device="cpu")
        with torch.no_grad():
            bias.bias.copy_(torch.rand((size,)))
        model = te_ops.Sequential(  #                 | Inputs  | Outputs
            torch.nn.Identity(),  #                   | x1      | x1
            te_ops.MakeExtraOutput(in_place=True),  # | x1      | x1 [x1]
            bias,  #                                  | x1      | h1 (= x1 + b)
            te_ops.MakeExtraOutput(in_place=True),  # | h1      | h1 [h1]
            te_ops.AddExtraInput(in_place=True),  #   | h1 [x2] | x2 (= x2 + h1)
            te_ops.MakeExtraOutput(in_place=True),  # | x2      | x2 [x2]
            torch.nn.Identity(),  #                   | x2      | x2
            bias,  #                                  | x2      | h2 (= x2 + b)
            te_ops.AddExtraInput(in_place=True),  #   | h2 [x3] | x3 (= x3 + h2)
            te_ops.MakeExtraOutput(in_place=True),  # | x3      | x3 [x3]
            te_ops.AddExtraInput(in_place=True),  #   | x3 [x4] | x4 (= x4 + x3)
            torch.nn.Identity(),  #                   | x4      | x4
            te_ops.Identity(),  #                     | x4      | x4
            te_ops.MakeExtraOutput(in_place=True),  # | x4      | x4 [x4]
            te_ops.Identity(),  #                     | x4      | x4
        )

        # Create input tensors
        x1 = torch.rand((size,))
        x2 = torch.rand((size,))
        x3 = torch.rand((size,))
        x4 = torch.rand((size,))

        # Save original input tensor values (in-place ops overwrite x2..x4)
        x1_orig = x1.clone()
        x2_orig = x2.clone()
        x3_orig = x3.clone()
        x4_orig = x4.clone()

        # Run forward
        ys = model(x1, x2, x3, x4)

        # Check whether outputs match (x4, x1, h1, x2, x3, x4)
        assert len(ys) == 6
        assert ys[0].data_ptr() == x4.data_ptr()
        assert ys[1].data_ptr() == x1.data_ptr()
        assert ys[2].data_ptr() not in [x.data_ptr() for x in (x1, x2, x3, x4)]
        assert ys[3].data_ptr() == x2.data_ptr()
        assert ys[4].data_ptr() == x3.data_ptr()
        assert ys[5].data_ptr() == x4.data_ptr()

        # Check whether tensors have correct values
        b = bias.bias
        h1 = ys[2]
        torch.testing.assert_close(x1, x1_orig)
        torch.testing.assert_close(h1, x1_orig + b)
        torch.testing.assert_close(x2, x2_orig + h1)
        torch.testing.assert_close(x3, x3_orig + x2 + b)
        torch.testing.assert_close(x4, x4_orig + x3)

335
336
337
338
339
340

class TestFuser:
    """Tests for operation fusion infrastructure"""

    @classmethod
    def setup_class(cls) -> None:
        """pytest xunit-style class hook: reset RNG states for reproducibility"""
        reset_rng_states()

    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_scale_update(
        self,
        size: int = 16,
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ) -> None:
        """Test FP8 scaling factors with delayed scaling recipe"""

        # FP8 recipe
        margin = 2
        fp8_format = transformer_engine.common.recipe.Format.HYBRID
        recipe = transformer_engine.common.recipe.DelayedScaling(
            margin=margin,
            fp8_format=fp8_format,
            amax_history_len=8,
            amax_compute_algo="max",
        )

        # Construct model
        with te.fp8_model_init(recipe=recipe):
            model = te_ops.basic.BasicLinear(
                size,
                size,
                device=device,
                dtype=dtype,
            )

        # Training steps
        w_vals = [2, 5, 3, 11]
        x_vals = [7, 3, 5]
        dy_vals = [1, 2, 1]
        with torch.no_grad():
            model.weight.fill_(w_vals[0])
        for step in range(3):

            # Data tensors
            x = torch.full(
                (size, size),
                x_vals[step],
                dtype=dtype,
                device=device,
                requires_grad=True,
            )
            dy = torch.full(
                (size, size),
                dy_vals[step],
                dtype=dtype,
                device=device,
            )

            # Training step; the weight is then overwritten so the next
            # step observes a new weight amax
            with te.fp8_autocast(fp8_recipe=recipe):
                y = model(x)
            y.backward(dy)
            with torch.no_grad():
                model.weight.fill_(w_vals[step + 1])

            # Check that output tensors match expected
            y_val_ref = w_vals[step] * x_vals[step] * size
            dx_val_ref = w_vals[step] * dy_vals[step] * size
            torch.testing.assert_close(
                y,
                torch.full_like(y, y_val_ref),
                **dtype_tols(tex.DType.kFloat8E4M3),
            )
            torch.testing.assert_close(
                x.grad,
                torch.full_like(x.grad, dx_val_ref),
                **dtype_tols(tex.DType.kFloat8E5M2),
            )

            # Check that scaling factors match expected: with the "max"
            # amax algorithm, scale = max_representable / max(amax history)
            # / 2**margin
            w_amax_ref = max(w_vals[: step + 1])
            x_amax_ref = max(x_vals[: step + 1])
            dy_amax_ref = max(dy_vals[: step + 1])
            w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin)
            x_scale_ref = (fp8_format.value.max_fwd / x_amax_ref) / (2**margin)
            dy_scale_ref = (fp8_format.value.max_bwd / dy_amax_ref) / (2**margin)
            w_scale = model.get_quantizer("forward", 1).scale
            x_scale = model.get_quantizer("forward", 0).scale
            dy_scale = model.get_quantizer("backward", 0).scale
            torch.testing.assert_close(w_scale, torch.full_like(w_scale, w_scale_ref))
            torch.testing.assert_close(x_scale, torch.full_like(x_scale, x_scale_ref))
            torch.testing.assert_close(dy_scale, torch.full_like(dy_scale, dy_scale_ref))

    @pytest.mark.parametrize("init_dtype", _dtypes)
    @pytest.mark.parametrize("final_dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_dtype_cast(
        self,
        *,
        size: int = 32,
        init_dtype: torch.dtype,
        final_dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
    ) -> None:
        """Check dtype cast functions"""

        # Skip invalid configurations
        in_shape = (size, size)
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data, in the lowest precision involved in the cast
        dtype = torch.float32
        if torch.float16 in (init_dtype, final_dtype):
            dtype = torch.float16
        if torch.bfloat16 in (init_dtype, final_dtype):
            dtype = torch.bfloat16
        w_ref, w_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )

        # Construct operation
        with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype)
        with torch.no_grad():
            op.weight.copy_(w_test)
            del w_test

        # Cast operation dtype
        if final_dtype == torch.float32:
            op.float()
        elif final_dtype == torch.float16:
            op.half()
        elif final_dtype == torch.bfloat16:
            op.bfloat16()

        # Check weights
        assert isinstance(op.weight, QuantizedTensor) == with_quantization
        assert op.weight.dtype == final_dtype
        w_test = op.weight.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(w_test, w_ref, **dtype_tols(dtype))

        # Check forward and backward pass
        x = torch.zeros(
            in_shape,
            dtype=init_dtype,
            device=device,
            requires_grad=True,
        )
        y = op(x)
        y.backward(torch.zeros_like(y))
        assert y.dtype == final_dtype
        assert x.grad.dtype == init_dtype
        assert op.weight.grad.dtype == final_dtype

    @pytest.mark.parametrize("model_dtype", _dtypes)
    @pytest.mark.parametrize("autocast_dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_pyt_autocast(
        self,
        *,
        size: int = 32,
        model_dtype: torch.dtype,
        autocast_dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weights: bool = False,
    ) -> None:
        """Test with PyTorch autocast"""
        device = torch.device(device)

        # Skip invalid configurations
        in_shape = (size, size)
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Construct operation
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weights, recipe=recipe):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype)

        # Check forward and backward pass
        x = torch.zeros(
            in_shape,
            dtype=model_dtype,
            device=device,
            requires_grad=True,
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
                y = op(x)
        y.backward(torch.zeros_like(y))
        assert y.dtype == autocast_dtype
        assert x.grad.dtype == model_dtype
        assert op.weight.grad.dtype == model_dtype

        # Check forward and backward pass (swapped context order)
        if quantized_compute:
            x.grad = None
            op.weight.grad = None
            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
                with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
                    y = op(x)
            y.backward(torch.zeros_like(y))
            assert y.dtype == autocast_dtype
            assert x.grad.dtype == model_dtype
            assert op.weight.grad.dtype == model_dtype

549
550
551
552
553
554

class TestBasicOps:
    """Tests for individual operations"""

    @classmethod
    def setup_class(cls) -> None:
        """pytest xunit-style class hook: reset RNG states for reproducibility"""
        reset_rng_states()

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_identity(
        self,
        *,
        in_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device,
        quantization: Optional[str],
    ) -> None:
        """Identity op returns its input and passes the gradient through unchanged"""

        # Skip unsupported configurations
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Reference/test pairs for the input and the output gradient
        tensor_kwargs = dict(
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
        )
        x_ref, x_test = make_reference_and_test_tensors(in_shape, **tensor_kwargs)
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape, requires_grad=False, **tensor_kwargs
        )

        # Expected results: identity in both directions
        y_ref, dx_ref = x_ref, dy_ref

        # Run the fusible operation
        op = te_ops.Identity()
        y_test = op(x_test)
        y_test.backward(dy_test)

        # Compare in float64 on CPU; identity must be bit-exact
        exact = dict(rtol=0, atol=0)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **exact)
        torch.testing.assert_close(dx_test, dx_ref, **exact)

        # Negated references must NOT match, so the check cannot pass trivially
        for actual, expected in ((y_test, y_ref), (dx_test, dx_ref)):
            with pytest.raises(AssertionError):
                torch.testing.assert_close(actual, -expected, **exact)

    @pytest.mark.parametrize(
        "shapes",
        (
            ((1, 2, 3, 4), (2, 12)),
            ((5, 4, 3, 2), (-1, 6)),
            ((30,), (2, 3, -1)),
            ((6, 7), (3, -1, 7)),
        ),
    )
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
    def test_reshape(
        self,
        *,
        shapes: tuple[Iterable[int], Iterable[int]],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        memory_format: torch.memory_format = torch.contiguous_format,
        quantization: Optional[str],
    ) -> None:
        """Reshape op matches ``torch.Tensor.reshape`` in forward and backward"""
        in_shape, out_shape = shapes

        # Skip unsupported configurations
        if memory_format == torch.channels_last and len(in_shape) != 4:
            pytest.skip("torch.channels_last only supports 4D tensors")
        maybe_skip_quantization(quantization, device=device)
        with_quantization = quantization is not None

        # Reference/test pairs; the test input is laid out in the requested
        # memory format and re-made into a leaf that requires grad
        tensor_kwargs = dict(
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
        )
        x_ref, x_test = make_reference_and_test_tensors(in_shape, **tensor_kwargs)
        x_test = x_test.contiguous(memory_format=memory_format).detach().requires_grad_()
        dy_ref, dy_test = make_reference_and_test_tensors(
            x_ref.reshape(out_shape).size(), requires_grad=False, **tensor_kwargs
        )

        # Plain PyTorch implementation
        y_ref = x_ref.reshape(out_shape)
        y_ref.backward(dy_ref)

        # Fusible operation
        op = te_ops.Reshape(out_shape)
        y_test = op(x_test)
        y_test.backward(dy_test)

        # Compare exactly, after normalizing to contiguous float64 on CPU
        to_reference_layout = dict(
            dtype=torch.float64,
            device="cpu",
            memory_format=torch.contiguous_format,
        )
        y_test = y_test.to(**to_reference_layout)
        dx_test = x_test.grad.to(**to_reference_layout)
        torch.testing.assert_close(y_test, y_ref, rtol=0, atol=0)
        torch.testing.assert_close(dx_test, x_ref.grad, rtol=0, atol=0)

    @pytest.mark.parametrize("size", (1, 7, 32))
    @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (4, 3, 8, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", _devices)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_bias(
        self,
        *,
        size: int,
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device,
        quantization: Optional[str],
    ) -> None:
        """Bias op matches a broadcasted add in forward and backward"""

        # The last entry of the parametrized shape is a placeholder that is
        # replaced by the bias size
        in_shape = [*list(in_shape)[:-1], size]
        bias_view_shape = [1] * (len(in_shape) - 1) + [size]

        # Skip unsupported configurations
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data, generated in the order x, b, dy
        quantized_kwargs = dict(
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
        )
        x_ref, x_test = make_reference_and_test_tensors(in_shape, **quantized_kwargs)
        b_ref, b_test = make_reference_and_test_tensors(
            size,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape, requires_grad=False, **quantized_kwargs
        )

        # Plain PyTorch implementation
        y_ref = x_ref + b_ref.reshape(bias_view_shape)
        y_ref.backward(dy_ref)

        # Fusible operation, with its bias initialized from the test tensor
        op = te_ops.Bias(size, device=device, dtype=dtype)
        with torch.no_grad():
            op.bias.copy_(b_test)
            del b_test
        y_test = op(x_test)
        y_test.backward(dy_test)

        # Compare in float64 on CPU with dtype-dependent tolerances
        tols = dtype_tols(dtype)
        as_reference = dict(dtype=torch.float64, device="cpu")
        y_test = y_test.to(**as_reference)
        dx_test = x_test.grad.to(**as_reference)
        db_test = op.bias.grad.to(**as_reference)
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(db_test, b_ref.grad, **tols)

748
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("cast_forward", (False, True))
    @pytest.mark.parametrize("cast_backward", (False, True))
    def test_quantize(
        self,
        *,
        in_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device = "cuda",
        quantization: Optional[str],
        cast_forward: bool,
        cast_backward: bool,
    ) -> None:
        """Quantize op casts forward output / backward grad as requested

        The op must not change values (inputs are already representable in
        the quantization scheme). When a scheme is active, the output is a
        ``QuantizedTensor`` iff ``cast_forward`` and the input gradient is one
        iff ``cast_backward``.
        """

        # Skip invalid configurations
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, device=device)
        if quantization == "mxfp8":
            maybe_skip_quantization(quantization, dims=in_shape)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=True,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = x_ref
        dx_ref = dy_ref

        # Implementation with fusible operation
        op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
        recipe = make_recipe(quantization)
        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
            y_test = op(x_test)
        y_test.backward(dy_test)

        # Check tensor types
        if with_quantization:
            assert isinstance(y_test, QuantizedTensor) == cast_forward
            assert isinstance(x_test.grad, QuantizedTensor) == cast_backward

        # Check values (quantizing representable values is exact)
        tols = dict(rtol=0, atol=0)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, dx_ref, **tols)

    def _test_basic_linear(
        self,
        *,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        quantization: Optional[str] = None,
        quantized_compute: bool = False,
        quantized_input: bool = False,
        quantized_weight: bool = False,
        quantized_output: bool = False,
        quantized_grad_output: bool = False,
        quantized_grad_input: bool = False,
        accumulate_into_main_grad: bool = False,
    ) -> None:
        """Helper function for tests with GEMM

        Runs ``Quantize -> BasicLinear -> Quantize`` and compares the output,
        the input gradient and the weight gradient with
        ``torch.nn.functional.linear`` evaluated on the reference tensors.

        ``weight_shape`` is ``(out_features, in_features)``. The last entry of
        ``in_shape`` is replaced by ``in_features``. Configurations that the
        chosen recipe cannot handle (or, on ROCm without hipBLASLt, that
        rocBLAS cannot handle) are skipped.

        ``weight.main_grad`` is pre-filled with 0.5. With
        ``accumulate_into_main_grad`` the weight gradient is read back from
        ``main_grad``. Otherwise ``main_grad`` must stay untouched.
        """
        # Reconcile tensor shapes with the weight matrix
        out_features, in_features = weight_shape
        in_shape = [*list(in_shape)[:-1], in_features]
        out_shape = [*in_shape[:-1], out_features]

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        quantization_needed = bool(
            quantized_compute
            or quantized_input
            or quantized_weight
            or quantized_output
            or quantized_grad_output
            or quantized_grad_input
        )
        if quantization is None:
            if quantization_needed:
                pytest.skip("Quantization scheme is not specified")
        elif not quantization_needed:
            pytest.skip("Quantization scheme is not used")
        fp8_recipes = ("fp8", "fp8_delayed_scaling", "fp8_current_scaling")
        if quantization in fp8_recipes and not quantized_compute:
            if quantized_output:
                pytest.skip("FP8 output is only supported with FP8 GEMMs")
            if quantized_grad_input:
                pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
        if quantization not in (None, "fp8") and (quantized_output or quantized_grad_input):
            pytest.skip("Recipe does not support quantized GEMM output")
        # use_hipblaslt only exists on HIP builds, so IS_HIP_EXTENSION must short-circuit first
        if (
            IS_HIP_EXTENSION
            and not use_hipblaslt()
            and accumulate_into_main_grad
            and dtype != torch.float32
            and not quantized_compute
        ):
            pytest.skip("Parameters combination is not supported by ROCBLAS")

        # Random data (creation order kept as-is)
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=quantized_input,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=quantized_grad_output,
            requires_grad=False,
        )

        # Reference result from plain PyTorch
        y_ref = torch.nn.functional.linear(x_ref, w_ref)
        y_ref.backward(dy_ref)

        # Fusible-op pipeline under test
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            linear_op = te_ops.BasicLinear(
                in_features,
                out_features,
                device=device,
                dtype=dtype,
                accumulate_into_main_grad=accumulate_into_main_grad,
            )
        with torch.no_grad():
            linear_op.weight.copy_(w_test)
            del w_test
            linear_op.weight.main_grad = torch.full_like(
                linear_op.weight, 0.5, dtype=torch.float32
            )
        pipeline = te_ops.Sequential(
            te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
            linear_op,
            te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = pipeline(x_test)
        y_test.backward(dy_test)

        # Tolerances: FP8 wins over TF32, which wins over the plain dtype
        if quantized_compute or quantized_output or quantized_grad_input:
            tols = dtype_tols(tex.DType.kFloat8E4M3)
        elif dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        else:
            tols = dtype_tols(dtype)

        # Compare against the reference
        exact = dict(rtol=0, atol=0)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        weight = linear_op.weight
        if accumulate_into_main_grad:
            # weight.grad, if it was allocated at all, must be all zeros
            if weight.grad is not None:
                torch.testing.assert_close(weight.grad, torch.zeros_like(weight.grad), **exact)
            dw_test = weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5
        else:
            dw_test = weight.grad.to(dtype=torch.float64, device="cpu")
            # main_grad must still hold its initial fill value
            torch.testing.assert_close(
                weight.main_grad, torch.full_like(weight.main_grad, 0.5), **exact
            )
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)

941
942
    @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5)))
    @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
    def test_basic_linear(
        self,
        *,
        weight_shape: tuple[int, int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        quantization: Optional[str],
        accumulate_into_main_grad: bool,
    ) -> None:
        """GEMM

        Thin wrapper around ``_test_basic_linear``. Quantized compute is
        turned on exactly when a quantization recipe is supplied.
        """
        use_quantized_gemm = quantization is not None
        self._test_basic_linear(
            quantization=quantization,
            quantized_compute=use_quantized_gemm,
            accumulate_into_main_grad=accumulate_into_main_grad,
            weight_shape=weight_shape,
            in_shape=in_shape,
            dtype=dtype,
        )

    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_compute", (False, True))
    @pytest.mark.parametrize("quantized_input", (False, True))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    @pytest.mark.parametrize("quantized_output", (False, True))
    @pytest.mark.parametrize("quantized_grad_output", (False, True))
    @pytest.mark.parametrize("quantized_grad_input", (False, True))
    def test_basic_linear_quantized(
        self,
        *,
        quantization: str,
        quantized_compute: bool,
        quantized_input: bool,
        quantized_weight: bool,
        quantized_output: bool,
        quantized_grad_output: bool,
        quantized_grad_input: bool,
    ) -> None:
        """GEMM with quantized inputs, weights, outputs and/or gradients

        Runs in BF16 over every combination of the quantization flags.
        Combinations that the recipe cannot handle are skipped inside
        ``_test_basic_linear``.
        """
        # The unquantized case is covered by test_basic_linear
        if quantization is None:
            pytest.skip("Skipping case without quantization")
        quantization_flags = dict(
            quantized_compute=quantized_compute,
            quantized_input=quantized_input,
            quantized_weight=quantized_weight,
            quantized_output=quantized_output,
            quantized_grad_output=quantized_grad_output,
            quantized_grad_input=quantized_grad_input,
        )
        self._test_basic_linear(
            dtype=torch.bfloat16,
            quantization=quantization,
            **quantization_flags,
        )

998
    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_compute", (False, True))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    @pytest.mark.parametrize("input_requires_grad", (False, True))
    @pytest.mark.parametrize("weight_requires_grad", (False, True))
    def test_linear(
        self,
        *,
        bias: bool,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_compute: bool,
        quantized_weight: bool,
        input_requires_grad: bool,
        weight_requires_grad: bool,
    ) -> None:
        """GEMM + bias

        Compares ``te_ops.Linear`` with ``torch.nn.functional.linear``.

        ``input_requires_grad`` controls whether the input tensor takes part
        in autograd. ``weight_requires_grad`` toggles ``requires_grad`` on all
        op parameters. Backward runs only if at least one of them is set, and
        only the gradients that were requested are checked.
        """

        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantization is None and (quantized_compute or quantized_weight):
            pytest.skip("Quantization scheme is not specified")
        if quantization is not None and not (quantized_compute or quantized_weight):
            pytest.skip("Quantization scheme is not used")

        # Random data
        # The input must honor input_requires_grad. Otherwise the
        # "no input gradient" path of the op is never exercised.
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=input_requires_grad,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref, b_test = None, None
        if bias:
            b_ref, b_test = make_reference_and_test_tensors(
                out_features,
                test_dtype=dtype,
                test_device=device,
            )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        # (w_ref always requires grad, so backward is valid even without dx)
        y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            op = te_ops.Linear(
                in_features,
                out_features,
                bias=bias,
                device=device,
                dtype=dtype,
            )
        with torch.no_grad():
            op.weight.copy_(w_test)
            if bias:
                op.bias.copy_(b_test)
            del w_test
            del b_test
            for param in op.parameters():
                param.requires_grad_(requires_grad=weight_requires_grad)
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = op(x_test)
        # Without any grad-requiring leaf the output has no graph to backprop
        if input_requires_grad or weight_requires_grad:
            y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        if input_requires_grad:
            dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        if weight_requires_grad:
            dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(dw_test, w_ref.grad, **tols)
            if bias:
                db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
                torch.testing.assert_close(db_test, b_ref.grad, **tols)
1107

1108
1109
    @pytest.mark.parametrize("weight_shape", ((7, 2), (32,)))
    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("zero_centered_gamma", (False, True))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_layer_norm(
        self,
        *,
        weight_shape: Iterable[int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        eps: float = 0.3,
        zero_centered_gamma: bool,
        quantization: Optional[str],
    ) -> None:
        """Layer norm

        Compares ``te_ops.LayerNorm`` with ``torch.nn.functional.layer_norm``.
        When a quantization recipe is given, the op is followed by a
        forward-only ``te_ops.Quantize`` and FP8 tolerances are used. With
        ``zero_centered_gamma`` the reference scale is ``1 + weight``.
        """
        # The last input dim is replaced by the normalized shape
        in_shape = [*list(in_shape)[:-1], *weight_shape]

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape, test_dtype=dtype, test_device=device
        )
        w_ref, w_test = make_reference_and_test_tensors(
            weight_shape, test_dtype=dtype, test_device=device
        )
        b_ref, b_test = make_reference_and_test_tensors(
            weight_shape, test_dtype=dtype, test_device=device
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape, test_dtype=dtype, test_device=device, requires_grad=False
        )

        # Reference result from plain PyTorch
        gamma_ref = w_ref + 1 if zero_centered_gamma else w_ref
        y_ref = torch.nn.functional.layer_norm(
            x_ref, weight_shape, weight=gamma_ref, bias=b_ref, eps=eps
        )
        y_ref.backward(dy_ref)

        # Fusible-op pipeline under test
        norm_op = te_ops.LayerNorm(
            weight_shape,
            eps=eps,
            device=device,
            dtype=dtype,
            zero_centered_gamma=zero_centered_gamma,
        )
        with torch.no_grad():
            norm_op.weight.copy_(w_test)
            norm_op.bias.copy_(b_test)
            del w_test, b_test
        use_quantization = quantization is not None
        recipe = make_recipe(quantization)
        pipeline = te_ops.Sequential(
            norm_op,
            te_ops.Quantize(forward=use_quantization, backward=False),
        )
        with te.fp8_autocast(enabled=use_quantization, fp8_recipe=recipe):
            y_test = pipeline(x_test)
        y_test.backward(dy_test)

        # Tolerances
        tols = dtype_tols(tex.DType.kFloat8E4M3) if use_quantization else dtype_tols(dtype)

        # Compare against the reference
        as_cpu64 = dict(dtype=torch.float64, device="cpu")
        y_test = y_test.to(**as_cpu64)
        dx_test = x_test.grad.to(**as_cpu64)
        dw_test = norm_op.weight.grad.to(**as_cpu64)
        db_test = norm_op.bias.grad.to(**as_cpu64)
        for actual, expected in (
            (y_test, y_ref),
            (dx_test, x_ref.grad),
            (dw_test, w_ref.grad),
            (db_test, b_ref.grad),
        ):
            torch.testing.assert_close(actual, expected, **tols)

    def test_layer_norm_autocast(
        self,
        *,
        weight_shape: Iterable[int] = (32,),
        in_shape: Iterable[int] = (32,),
        dtype: torch.dtype = torch.float16,
        autocast_dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        eps: float = 0.3,
    ) -> None:
        """Layer norm with PyTorch autocast

        Parameters are stored in ``dtype``, while the input and output
        gradient are in ``autocast_dtype``. Under ``torch.autocast`` the
        output must come out in ``autocast_dtype``. Activations and the input
        gradient use ``autocast_dtype`` tolerances. Parameter gradients use
        ``dtype`` tolerances.
        """
        # The last input dim is replaced by the normalized shape
        in_shape = [*list(in_shape)[:-1], *weight_shape]

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape, test_dtype=autocast_dtype, test_device=device
        )
        w_ref, w_test = make_reference_and_test_tensors(
            weight_shape, test_dtype=dtype, test_device=device
        )
        b_ref, b_test = make_reference_and_test_tensors(
            weight_shape, test_dtype=dtype, test_device=device
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape, test_dtype=autocast_dtype, test_device=device, requires_grad=False
        )

        # Reference result from plain PyTorch
        y_ref = torch.nn.functional.layer_norm(
            x_ref, weight_shape, weight=w_ref, bias=b_ref, eps=eps
        )
        y_ref.backward(dy_ref)

        # Fusible op under autocast
        norm_op = te_ops.LayerNorm(weight_shape, eps=eps, device=device, dtype=dtype)
        with torch.no_grad():
            norm_op.weight.copy_(w_test)
            norm_op.bias.copy_(b_test)
            del w_test, b_test
        with torch.autocast(device, dtype=autocast_dtype):
            y_test = norm_op(x_test)
        y_test.backward(dy_test)

        # Compare against the reference
        assert y_test.dtype == autocast_dtype
        as_cpu64 = dict(dtype=torch.float64, device="cpu")
        y_test = y_test.to(**as_cpu64)
        dx_test = x_test.grad.to(**as_cpu64)
        dw_test = norm_op.weight.grad.to(**as_cpu64)
        db_test = norm_op.bias.grad.to(**as_cpu64)
        activation_tols = dtype_tols(autocast_dtype)
        param_tols = dtype_tols(dtype)
        torch.testing.assert_close(y_test, y_ref, **activation_tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **activation_tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **param_tols)
        torch.testing.assert_close(db_test, b_ref.grad, **param_tols)

1278
1279
    @pytest.mark.parametrize("weight_shape", ((19,), (64,)))
    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("zero_centered_gamma", (False, True))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_rmsnorm(
        self,
        *,
        weight_shape: Iterable[int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        eps: float = 0.3,
        zero_centered_gamma: bool,
        quantization: Optional[str],
    ) -> None:
        """RMSNorm

        Compares ``te_ops.RMSNorm`` with a plain PyTorch reference
        ``y = x / sqrt(mean(x^2) + eps) * gamma``. The mean is taken over the
        trailing ``weight_shape`` dims, and ``gamma`` is ``1 + w`` when
        ``zero_centered_gamma`` is set (``w`` otherwise). When a quantization
        recipe is given, the op is followed by a forward-only
        ``te_ops.Quantize`` and FP8 tolerances are used.
        """

        # Make input and weight shapes consistent
        in_shape = list(in_shape)[:-1] + list(weight_shape)

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            weight_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        # Normalize over the trailing dims covered by weight_shape
        inner_dims = tuple(range(len(in_shape) - len(weight_shape), len(in_shape)))
        var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape)
        if zero_centered_gamma:
            y_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref)
        else:
            y_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.RMSNorm(
            weight_shape,
            eps=eps,
            device=device,
            dtype=dtype,
            zero_centered_gamma=zero_centered_gamma,
        )
        with torch.no_grad():
            op.weight.copy_(w_test)
            del w_test
        quantized_compute = quantization is not None
        recipe = make_recipe(quantization)
        forward = te_ops.Sequential(
            op,
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)

1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
    @pytest.mark.parametrize("in_shape", ((32,), (6, 16, 64), (32, 64)))
    @pytest.mark.parametrize("dtype", _dtypes)
    def test_l2normalization(
        self,
        *,
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        eps: float = 1e-6,
    ) -> None:
        """L2 Normalization

        Compares ``te_ops.L2Normalization`` with the reference
        ``y = x * rsqrt(sum(x^2, dim=-1) + eps)`` in the forward and backward
        passes.
        """
        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape, test_dtype=dtype, test_device=device
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape, test_dtype=dtype, test_device=device, requires_grad=False
        )

        # Reference: x / sqrt(||x||_2^2 + eps) over the last dim
        sum_of_squares = x_ref.pow(2).sum(dim=-1, keepdim=True)
        y_ref = x_ref * torch.rsqrt(sum_of_squares + eps)
        y_ref.backward(dy_ref)

        # Fusible op under test
        normalize = te_ops.L2Normalization(eps=eps)
        y_test = normalize(x_test)
        y_test.backward(dy_test)

        # Compare against the reference
        tols = dtype_tols(dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
Tim Moon's avatar
Tim Moon committed
1411

1412
    @pytest.mark.parametrize("in_place", (True, False))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_add_extra_input(
        self,
        *,
        in_shape: Iterable[int] = (32, 32),
        in_place: bool,
        dtype: torch.dtype,
        device: torch.device,
        quantization: Optional[str],
    ) -> None:
        """Add two tensors

        Join in compute graph.

        Checks that ``te_ops.AddExtraInput`` returns ``x1 + x2`` and passes
        the output gradient through unchanged (bit-exact) to both inputs.
        When a quantization recipe is given, the inputs and the output
        gradient are quantized tensors and the forward tolerance follows
        their FP8 dtype.
        """

        # Skip invalid configurations
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x1_ref, x1_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
        )
        x2_ref, x2_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        # Build the sum out of place. ``detach()`` shares storage, so an
        # in-place ``+=`` on ``x2_ref.detach()`` would silently overwrite
        # ``x2_ref``.
        y_ref = x2_ref.detach() + x1_ref.detach()
        dx1_ref = dy_ref
        dx2_ref = dy_ref

        # Implementation with fusible operation
        op = te_ops.AddExtraInput(in_place=in_place)
        y_test = op(x1_test, x2_test)
        y_test.backward(dy_test)

        # Check results
        tols = dtype_tols(dtype)
        if with_quantization:
            tols = dtype_tols(x1_test._fp8_dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
        dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        # Gradient of an addition is the identity, so it must match exactly
        torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0)
        torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0)

1481
    @pytest.mark.parametrize("in_place", (True, False))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_make_extra_output(
        self,
        *,
        in_shape: Iterable[int] = (32, 32),
        in_place: bool,
        dtype: torch.dtype,
        device: torch.device,
        quantization: Optional[str],
    ) -> None:
        """Output tensor twice

        Split in compute graph.

        Both outputs of ``te_ops.MakeExtraOutput`` must equal the input
        exactly. The input gradient must be the sum of the two output
        gradients, within ``dtype`` tolerance.
        """
        # Skip invalid configurations
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data (all three tensors share the same construction options)
        tensor_options = dict(
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            test_is_quantized=with_quantization,
        )
        x_ref, x_test = make_reference_and_test_tensors(in_shape, **tensor_options)
        dy1_ref, dy1_test = make_reference_and_test_tensors(
            in_shape, requires_grad=False, **tensor_options
        )
        dy2_ref, dy2_test = make_reference_and_test_tensors(
            in_shape, requires_grad=False, **tensor_options
        )

        # Reference: both branches are the input itself
        y1_ref = y2_ref = x_ref
        loss_ref = (y1_ref * dy1_ref + y2_ref * dy2_ref).sum()
        loss_ref.backward()

        # Fusible op under test
        split_op = te_ops.MakeExtraOutput(in_place=in_place)
        y1_test, y2_test = split_op(x_test)
        loss_test = (y1_test * dy1_test + y2_test * dy2_test).sum()
        loss_test.backward()

        # Compare against the reference
        as_cpu64 = dict(dtype=torch.float64, device="cpu")
        y1_test = y1_test.to(**as_cpu64)
        y2_test = y2_test.to(**as_cpu64)
        dx_test = x_test.grad.to(**as_cpu64)
        torch.testing.assert_close(y1_test, y1_ref, rtol=0, atol=0)
        torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
        torch.testing.assert_close(dx_test, x_ref.grad, **dtype_tols(dtype))

1548
    @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
    @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("cache_quantized_input", (False, True))
    def test_activation(
        self,
        *,
        activation: str,
        out_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        cache_quantized_input: bool,
    ) -> None:
        """Activation functions

        Runs one activation op between two ``Quantize`` ops (grad-input
        quantization before it, output quantization after it, both active
        only when a quantization scheme is given) and compares output and
        input gradient against a float64 PyTorch reference.
        """

        # Gated activations split their input in half along the last dim,
        # so they need twice as many input features as they produce
        in_shape = list(out_shape)
        is_gated = activation in ("geglu", "reglu", "swiglu")
        if is_gated:
            in_shape[-1] *= 2

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        if cache_quantized_input:
            maybe_skip_quantization("fp8_current_scaling", device=device)

        # Random data. When the op caches a quantized input, the input is
        # generated with FP8 current-scaling quantization (presumably so the
        # cached copy adds no extra rounding; confirm in the helper).
        x_quantization = "fp8_current_scaling" if cache_quantized_input else None
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=x_quantization,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        def _tanh_gelu(t: torch.Tensor) -> torch.Tensor:
            return torch.nn.functional.gelu(t, approximate="tanh")

        elementwise_acts = {
            "gelu": _tanh_gelu,
            "relu": torch.nn.functional.relu,
        }
        gate_acts = {
            "geglu": _tanh_gelu,
            "reglu": torch.nn.functional.relu,
            "swiglu": torch.nn.functional.silu,
        }
        y_ref: torch.Tensor
        if activation in elementwise_acts:
            y_ref = elementwise_acts[activation](x_ref)
        elif activation in gate_acts:
            gate_half, linear_half = x_ref.chunk(2, dim=-1)
            y_ref = gate_acts[activation](gate_half) * linear_half
        else:
            raise ValueError(f"Unexpected activation function ({activation})")
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        recipe = make_recipe(quantization)
        op_class = {
            "gelu": te_ops.GELU,
            "relu": te_ops.ReLU,
            "geglu": te_ops.GEGLU,
            "reglu": te_ops.ReGLU,
            "swiglu": te_ops.SwiGLU,
        }[activation]
        forward = te_ops.Sequential(
            te_ops.Quantize(forward=False, backward=quantized_compute),
            op_class(cache_quantized_input=cache_quantized_input),
            te_ops.Quantize(forward=quantized_compute, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error: use FP8 tolerances whenever FP8 is involved
        uses_fp8 = quantized_compute or cache_quantized_input
        tols = dtype_tols(tex.DType.kFloat8E4M3 if uses_fp8 else dtype)

        # Check results in float64 on CPU
        as_f64_cpu = dict(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test.to(**as_f64_cpu), y_ref, **tols)
        torch.testing.assert_close(x_test.grad.to(**as_f64_cpu), x_ref.grad, **tols)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantize_forward", (False, True))
    @pytest.mark.parametrize("quantize_backward", (False, True))
    def test_swiglu(
        self,
        *,
        out_shape: Iterable[int] = (32, 32),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantize_forward: bool,
        quantize_backward: bool,
    ):
        """SwiGLU, optionally quantizing its output and/or grad input.

        Compares against ``silu(x1) * x2`` computed in float64, where
        ``x1, x2`` are the two halves of the input's last dimension.
        """

        # SwiGLU consumes twice as many features as it produces
        in_shape = list(out_shape)
        in_shape[-1] = 2 * in_shape[-1]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        needs_scheme = quantize_forward or quantize_backward
        if needs_scheme and not quantized_compute:
            pytest.skip("Quantization scheme has not been provided")
        maybe_skip_quantization(quantization, dims=in_shape, device=device)

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        gate_half, linear_half = x_ref.chunk(2, dim=-1)
        y_ref = torch.nn.functional.silu(gate_half) * linear_half
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        recipe = make_recipe(quantization)
        forward = te_ops.Sequential(
            te_ops.Quantize(forward=False, backward=quantize_backward),
            te_ops.SwiGLU(),
            te_ops.Quantize(forward=quantize_forward, backward=False),
        )
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(tex.DType.kFloat8E4M3 if quantized_compute else dtype)

        # Check results in float64 on CPU
        as_f64_cpu = dict(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test.to(**as_f64_cpu), y_ref, **tols)
        torch.testing.assert_close(x_test.grad.to(**as_f64_cpu), x_ref.grad, **tols)

1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
    @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
    @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", _devices)
    def test_constant_scale(
        self,
        *,
        scale: float,
        shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device,
    ):
        """ConstantScale: y = scale * x, so dx = scale * dy."""

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = scale * x_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        y_test = te_ops.ConstantScale(scale)(x_test)
        y_test.backward(dy_test)

        # Check output and input gradient in float64 on CPU
        tols = dtype_tols(dtype)
        for actual, expected in ((y_test, y_ref), (x_test.grad, x_ref.grad)):
            torch.testing.assert_close(
                actual.to(dtype=torch.float64, device="cpu"),
                expected,
                **tols,
            )

    @pytest.mark.parametrize("prob", (0.1, 0.5, 0.75))
    @pytest.mark.parametrize("is_training", (True, False))
    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16)))
    @pytest.mark.parametrize("dtype", _dtypes)
    def test_dropout(
        self,
        *,
        prob: float,
        is_training: bool,
        shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
    ):
        """Dropout in training mode (masked + rescaled) and eval mode (identity)."""

        # Inputs and grads lie in [0.5, 1.5), so a zero output can only
        # come from a dropped element
        x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
        x_test = x_ref.clone().requires_grad_()
        dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
        dy_test = dy_ref.clone()

        # Apply dropout (Module.eval() is equivalent to train(False))
        op = te_ops.Dropout(prob)
        op.train(is_training)
        y = op(x_test)
        y.backward(dy_test)

        if not is_training:
            # Eval mode must be an exact identity in both directions
            torch.testing.assert_close(y, x_ref, rtol=0, atol=0)
            torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0)
            return

        # Reconstruct the scaled keep-mask from the surviving outputs and
        # check that forward and backward apply the same mask
        keep_scale = ((y != 0) / (1 - prob)).to(dtype=dtype)
        torch.testing.assert_close(y, x_ref * keep_scale)
        torch.testing.assert_close(x_test.grad, dy_ref * keep_scale)

        # Two-sided z-test on the fraction of zeros: by the CLT the mean of
        # n Bernoulli(p) draws is ~ Normal(p, p*(1-p)/n). A |z| beyond the
        # 99.5th percentile (2.5758) means p < 1%, and we treat the dropout
        # distribution as incorrect.
        numel = y.numel()
        zero_fraction = 1 - torch.count_nonzero(y).item() / numel
        std_err = math.sqrt(prob * (1 - prob) / numel)
        z_score = (zero_fraction - prob) / std_err
        assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval"

1797
1798
1799
1800
1801
1802

class TestFusedOps:
    """Tests for fused operations"""

    @classmethod
    def setup_class(cls) -> None:
        """pytest class-level setup hook: reset RNG states for this class's tests."""
        # Was `@staticmethod` with a `cls` parameter; `@classmethod` is the
        # idiomatic form and pytest still invokes it with the class.
        reset_rng_states()

1805
1806
    @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
    @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_weight", (False, True))
    def test_forward_linear_bias_activation(
        self,
        *,
        bias: bool = True,
        weight_shape: tuple[int, int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool,
    ) -> None:
        """Forward GEMM + bias + activation

        The model holds a single ``te_ops.Linear``. The test checks that its
        forward pass runs as one ``ForwardLinearBiasActivation`` op and that
        outputs and gradients match a float64 PyTorch reference.
        NOTE(review): no activation op is present, so only the
        no-activation path of the fused op is exercised.

        ``quantized_weight`` controls whether the weight is created as a
        quantized parameter (``fp8_model_init``). ``quantization`` controls
        whether compute runs under ``fp8_autocast``.
        """

        # Make input and weight shapes consistent (trailing -1 is a placeholder)
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        if quantized_weight and not quantized_compute:
            # Quantized weights need a recipe to quantize with
            pytest.skip("Quantization scheme has not been provided")
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if dtype not in (torch.float16, torch.bfloat16):
            pytest.skip(
                "FP8 fused linear-bias-activation is only supported with FP16 or BF16 output"
            )

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref, b_test = None, None
        if bias:
            b_ref, b_test = make_reference_and_test_tensors(
                out_features,
                test_dtype=dtype,
                test_device=device,
            )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
        y_ref.backward(dy_ref)

        # Implementation with fusible operations. Weight quantization follows
        # the `quantized_weight` parameter. It previously followed
        # `quantized_compute`, which left `quantized_weight` unused.
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                ),
            )
        with torch.no_grad():
            model[0].weight.copy_(w_test)
            if bias:
                model[0].bias.copy_(b_test)
            del w_test
            del b_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)

        # Check that forward operations have been fused
        forward_ops = model._module_groups[0]._forward_ops
        assert len(forward_ops) == 1
        assert isinstance(forward_ops[0][0], ForwardLinearBiasActivation)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
        if bias:
            db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(db_test, b_ref.grad, **tols)

1914
1915
    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_forward_linear_bias_add(
        self,
        *,
        bias: bool,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool = False,
    ) -> None:
        """Forward GEMM + bias + add

        Runs ``Linear`` followed by an in-place ``AddExtraInput``. Checks that
        the forward pass becomes a single ``ForwardLinearBiasAdd`` op and that
        outputs and all gradients match a float64 PyTorch reference.
        """

        # Derive consistent shapes; the trailing -1 of in_shape is a placeholder
        out_features, in_features = weight_shape
        leading_dims = list(in_shape)[:-1]
        in_shape = [*leading_dims, in_features]
        out_shape = [*leading_dims, out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")

        # Random data. Keep this call order; each call presumably draws from
        # the global RNG.
        x1_ref, x1_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref = b_test = None
        if bias:
            b_ref, b_test = make_reference_and_test_tensors(
                out_features,
                test_dtype=dtype,
                test_device=device,
            )
        x2_ref, x2_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x1_ref, w_ref, bias=b_ref) + x2_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                ),
                te_ops.AddExtraInput(in_place=True),
            )
        linear = model[0]
        with torch.no_grad():
            linear.weight.copy_(w_test)
            if bias:
                linear.bias.copy_(b_test)
            del w_test, b_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = model(x1_test, x2_test)
        y_test.backward(dy_test)

        # The linear and add ops should collapse into one fused forward op
        forward_ops = model._module_groups[0]._forward_ops
        assert len(forward_ops) == 1
        assert isinstance(forward_ops[0][0], ForwardLinearBiasAdd)

        # Expected numerical error (FP8 wins over the TF32 allowance)
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)
        elif dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        else:
            tols = dtype_tols(dtype)

        # Compare every output and gradient against the reference in float64
        comparisons = [
            (y_test, y_ref),
            (x1_test.grad, x1_ref.grad),
            (x2_test.grad, x2_ref.grad),
            (linear.weight.grad, w_ref.grad),
        ]
        if bias:
            comparisons.append((linear.bias.grad, b_ref.grad))
        for actual, expected in comparisons:
            torch.testing.assert_close(
                actual.to(dtype=torch.float64, device="cpu"),
                expected,
                **tols,
            )

Jan Bielak's avatar
Jan Bielak committed
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
    @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_forward_linear_scale_add(
        self,
        *,
        scale: float,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool = False,
    ) -> None:
        """Forward GEMM + scale + add

        Runs bias-free ``Linear``, then ``ConstantScale``, then in-place
        ``AddExtraInput``, then ``Quantize``. Checks that the first three
        collapse into ``ForwardLinearScaleAdd``, that ``Quantize`` stays a
        separate op, and that results match a float64 reference.
        """

        # Derive consistent shapes; the trailing -1 of in_shape is a placeholder
        out_features, in_features = weight_shape
        leading_dims = list(in_shape)[:-1]
        in_shape = [*leading_dims, in_features]
        out_shape = [*leading_dims, out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")

        # Random data (call order preserved)
        x1_ref, x1_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        x2_ref, x2_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x1_ref, w_ref) * scale + x2_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=False,
                    device=device,
                    dtype=dtype,
                ),
                te_ops.ConstantScale(scale),
                te_ops.AddExtraInput(in_place=True),
                te_ops.Quantize(),
            )
        linear = model[0]
        with torch.no_grad():
            linear.weight.copy_(w_test)
            del w_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = model(x1_test, x2_test)
        y_test.backward(dy_test)

        # Expect exactly: [fused linear+scale+add, standalone quantize]
        expected_types = (ForwardLinearScaleAdd, te_ops.Quantize)
        forward_ops = model._module_groups[0]._forward_ops
        assert len(forward_ops) == len(expected_types)
        for entry, op_type in zip(forward_ops, expected_types):
            assert isinstance(entry[0], op_type)

        # Expected numerical error (FP8 wins over the TF32 allowance)
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)
        elif dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        else:
            tols = dtype_tols(dtype)

        # Compare output and gradients against the reference in float64
        comparisons = (
            (y_test, y_ref),
            (x1_test.grad, x1_ref.grad),
            (x2_test.grad, x2_ref.grad),
            (linear.weight.grad, w_ref.grad),
        )
        for actual, expected in comparisons:
            torch.testing.assert_close(
                actual.to(dtype=torch.float64, device="cpu"),
                expected,
                **tols,
            )

2130
2131
2132
2133
    @pytest.mark.parametrize("activation", ("relu", "gelu"))
    @pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_backward_activation_bias(
        self,
        *,
        activation: str,
        out_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
    ) -> None:
        """Backward dact + dbias + quantize

        Model is Quantize(backward only) -> Bias -> activation. With a
        quantization scheme, activation and bias backward should fuse into
        ``BackwardActivationBias``. Without one, the three ops stay separate.
        """

        # Bias and activation preserve shape
        in_shape = list(out_shape)
        hidden_size = in_shape[-1]

        # Skip invalid configurations
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, device=device)
        bad_mxfp8_shape = len(in_shape) < 2 or hidden_size % 32 != 0
        if quantization == "mxfp8" and bad_mxfp8_shape:
            pytest.skip("Unsupported tensor size for MXFP8")

        # Random data (call order preserved)
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref, b_test = make_reference_and_test_tensors(
            hidden_size,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation: bias broadcast over all leading dims
        bias_view = b_ref.reshape([1] * (len(in_shape) - 1) + [hidden_size])
        pre_activation = x_ref + bias_view
        reference_acts = {
            "gelu": lambda t: torch.nn.functional.gelu(t, approximate="tanh"),
            "relu": torch.nn.functional.relu,
        }
        if activation not in reference_acts:
            raise ValueError(f"Unexpected activation function ({activation})")
        y_ref = reference_acts[activation](pre_activation)
        y_ref.backward(dy_ref)

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
        act_type = te_ops.GELU if activation == "gelu" else te_ops.ReLU
        model = te_ops.Sequential(
            te_ops.Quantize(forward=False, backward=True),
            te_ops.Bias(hidden_size, device=device, dtype=dtype),
            act_type(),
        )
        bias_op = model[1]
        with torch.no_grad():
            bias_op.bias.copy_(b_test)
            del b_test
        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
            y_test = model(x_test)
        y_test.backward(dy_test)

        # Check the backward op sequence (fused only when quantizing)
        if with_quantization:
            expected_types = (BackwardActivationBias, te_ops.Quantize)
        else:
            expected_types = (act_type, te_ops.Bias, te_ops.Quantize)
        backward_ops = model._module_groups[0]._backward_ops
        assert len(backward_ops) == len(expected_types)
        for entry, op_type in zip(backward_ops, expected_types):
            assert isinstance(entry[0], op_type)

        # Expected numerical error
        tols = dtype_tols(tex.DType.kFloat8E4M3 if with_quantization else dtype)

        # Compare output and gradients against the reference in float64
        comparisons = (
            (y_test, y_ref),
            (x_test.grad, x_ref.grad),
            (bias_op.bias.grad, b_ref.grad),
        )
        for actual, expected in comparisons:
            torch.testing.assert_close(
                actual.to(dtype=torch.float64, device="cpu"),
                expected,
                **tols,
            )

2223
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_backward_linear_add(
        self,
        *,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool = False,
    ) -> None:
        """Backward dgrad GEMM + add

        Model is MakeExtraOutput(in_place) -> bias-free Linear. The input
        feeds both the linear layer and an identity second output, so its
        gradient is the sum of both paths. Checks that backward collapses
        into one ``BackwardLinearAdd`` op and that results match a float64
        reference.
        """

        # Make input and weight shapes consistent (trailing -1 is a placeholder)
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        if quantized_weight and not quantized_compute:
            # Quantized weights need a recipe to quantize with
            pytest.skip("Quantization scheme has not been provided")
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        dy1_ref, dy1_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        dy2_ref, dy2_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation: x feeds both the linear layer and an
        # identity output, so x.grad accumulates both gradient paths
        y1_ref = torch.nn.functional.linear(x_ref, w_ref)
        y2_ref = x_ref
        (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()

        # Implementation with fusible operations. Pass the test's recipe to
        # fp8_model_init, as the sibling tests do. Without it, quantized
        # weights would presumably use the library's default recipe rather
        # than `quantization`.
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.MakeExtraOutput(in_place=True),
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=False,
                    device=device,
                    dtype=dtype,
                ),
            )
        with torch.no_grad():
            model[1].weight.copy_(w_test)
            del w_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y1_test, y2_test = model(x_test)
        (y1_test * dy1_test + y2_test * dy2_test).sum().backward()

        # Check that backward operations have been fused
        backward_ops = model._module_groups[0]._backward_ops
        assert len(backward_ops) == 1
        assert isinstance(backward_ops[0][0], BackwardLinearAdd)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y1_test = y1_test.to(dtype=torch.float64, device="cpu")
        y2_test = y2_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y1_test, y1_ref, **tols)
        torch.testing.assert_close(y2_test, y2_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
2323

Jan Bielak's avatar
Jan Bielak committed
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
    @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_backward_linear_scale(
        self,
        *,
        scale: float,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool = False,
    ) -> None:
        """Backward dgrad GEMM + scale

        Runs ``Linear(bias=False)`` followed by ``ConstantScale(scale)`` with
        fusible ops, checks that the backward pass is fused into a single
        ``BackwardLinearScale`` op, and compares the output and the input and
        weight gradients against a plain PyTorch reference.
        """

        # Make input and weight shapes consistent. The last entry of in_shape
        # is always replaced by the weight's input feature count.
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape, device=device)
        if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
            pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            (out_features, in_features),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = torch.nn.functional.linear(x_ref, w_ref) * scale
        y_ref.backward(dy_ref)

        # Implementation with fusible operations. The recipe is passed to
        # fp8_model_init (as the checkpointing tests do) so that any quantized
        # weights are created with the same recipe used for compute.
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model = te_ops.Sequential(
                te_ops.Linear(
                    in_features,
                    out_features,
                    bias=False,
                    device=device,
                    dtype=dtype,
                ),
                te_ops.ConstantScale(scale),
            )
        with torch.no_grad():
            model[0].weight.copy_(w_test)
            del w_test
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = model(x_test)
        (y_test * dy_test).sum().backward()

        # Check that backward operations have been fused
        backward_ops = model._module_groups[0]._backward_ops
        assert len(backward_ops) == 1
        assert isinstance(backward_ops[0][0], BackwardLinearScale)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if dtype == torch.float32:
            tols = dtype_tols(torch.float16)  # TF32 GEMM
        if quantized_compute:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
        torch.testing.assert_close(dw_test, w_ref.grad, **tols)
2416
2417
2418
2419
2420
2421
2422


class TestCheckpointing:
    """Tests for checkpointing"""

    @staticmethod
    def setup_class(cls) -> None:
        # Called by pytest once before the tests in this class; resets RNG
        # state through the shared test utility.
        reset_rng_states()

    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_weight", (False, True))
    def test_linear(
        self,
        *,
        pre_checkpoint_steps: int = 2,
        post_checkpoint_steps: int = 2,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool,
    ) -> None:
        """Check checkpointing with linear op

        Trains a linear model for ``pre_checkpoint_steps`` SGD steps and
        serializes the model and optimizer state to an in-memory checkpoint.
        The original model is then trained for ``post_checkpoint_steps`` more
        steps. A freshly constructed model is restored from the checkpoint
        and trained on identical data. Both runs must match exactly:
        parameters, parameter grads, outputs, and input grads.
        """

        # Make input and weight shapes consistent. The last entry of in_shape
        # is always replaced by the weight's input feature count.
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape, device=device)

        recipe = make_recipe(quantization)

        def make_model() -> te_ops.Sequential:
            """Construct the model under test, with quantized weights if requested."""
            with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
                return te_ops.Sequential(
                    te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
                )

        def train_step(
            model: te_ops.Sequential,
            optim: torch.optim.Optimizer,
            x: torch.Tensor,
            dy: torch.Tensor,
        ) -> torch.Tensor:
            """Run one forward/backward/SGD step and return the model output."""
            optim.zero_grad()
            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
                y = model(x)
            y.backward(dy)
            optim.step()
            return y

        # Construct model
        model_save = make_model()
        optim_save = torch.optim.SGD(model_save.parameters(), lr=0.25)

        # Warmup training steps (x is drawn before dy on every step)
        for _ in range(pre_checkpoint_steps):
            x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
            dy = torch.randn(out_shape, dtype=dtype, device=device)
            train_step(model_save, optim_save, x, dy)

        # Save checkpoint to an in-memory buffer
        with io.BytesIO() as byte_stream:
            torch.save(
                {"model": model_save.state_dict(), "optim": optim_save.state_dict()},
                byte_stream,
            )
            checkpoint_bytes = byte_stream.getvalue()

        # Synthetic data for evaluation, shared by both runs
        xs_save = [
            torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
            for _ in range(post_checkpoint_steps)
        ]
        with torch.no_grad():
            xs_load = [x.clone().requires_grad_() for x in xs_save]
        dys = [
            torch.randn(out_shape, dtype=dtype, device=device) for _ in range(post_checkpoint_steps)
        ]

        # Training steps with original model
        ys_save = [train_step(model_save, optim_save, x, dy) for x, dy in zip(xs_save, dys)]

        # Load checkpoint. weights_only=False unpickles arbitrary objects; that
        # is acceptable here only because the bytes were produced by this test.
        model_load = make_model()
        optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25)
        state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False)
        model_load.load_state_dict(state_dict["model"])
        optim_load.load_state_dict(state_dict["optim"])

        # Training steps with loaded model
        ys_load = [train_step(model_load, optim_load, x, dy) for x, dy in zip(xs_load, dys)]

        # Check that original and loaded model match exactly
        tols = {"rtol": 0, "atol": 0}
        params_save = dict(model_save.named_parameters())
        params_load = dict(model_load.named_parameters())
        # Compare names first: zipping the parameter iterators would silently
        # ignore any parameter present in only one of the models.
        assert params_load.keys() == params_save.keys()
        for name, param_save in params_save.items():
            param_load = params_load[name]
            torch.testing.assert_close(param_load, param_save, **tols)
            torch.testing.assert_close(param_load.grad, param_save.grad, **tols)
        for y_load, y_save in zip(ys_load, ys_save):
            torch.testing.assert_close(y_load, y_save, **tols)
        for x_load, x_save in zip(xs_load, xs_save):
            torch.testing.assert_close(x_load.grad, x_save.grad, **tols)
2528
2529
2530
2531
2532
2533
2534


class TestSequentialModules:
    """Test for larger Sequentials with modules commonly used together"""

    @staticmethod
    def setup_class(cls) -> None:
2535
        reset_rng_states()
2536

Jan Bielak's avatar
Jan Bielak committed
2537
    @pytest.mark.parametrize("requires_grad", (False, True))
2538
2539
2540
2541
2542
2543
2544
2545
2546
    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
    @pytest.mark.parametrize("quantized_compute", (False, True))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_layernorm_mlp(
        self,
        *,
Jan Bielak's avatar
Jan Bielak committed
2547
        requires_grad: bool,
2548
2549
2550
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
2575
2576
2577
2578
2579
2580
2581
2582
2583
2584
2585
2586
2587
        bias: bool,
        normalization: str,
        quantized_compute: bool,
        quantized_weight: bool,
        dtype: torch.dtype,
        quantization: Optional[str],
        device: torch.device = "cuda",
        hidden_size: int = 32,
        sequence_length: int = 512,
        batch_size: int = 4,
        ffn_hidden_size: int = 64,
        layernorm_epsilon: float = 1e-5,
    ) -> None:
        """
        LayerNorm/RMSNorm + Linear + GELU + Linear

        Note that this test checks only if the module runs
        as when chaining multiple modules it is hard to validate
        numerical accuracy.
        """

        # Make input shape
        in_shape = (sequence_length, batch_size, hidden_size)
        ffn_shape = in_shape[:-1] + (ffn_hidden_size,)

        # Skip invalid configurations
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=ffn_shape, device=device)
        quantization_needed = quantized_compute or quantized_weight
        if quantization is None and quantization_needed:
            pytest.skip("Quantization scheme is not specified")
        if quantization is not None and not quantization_needed:
            pytest.skip("Quantization scheme is not used")

        # Random data
        _, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
Jan Bielak's avatar
Jan Bielak committed
2588
            requires_grad=requires_grad,
2589
2590
2591
2592
2593
2594
2595
2596
2597
2598
2599
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
2619
2620
2621
2622
2623
2624
2625
2626
2627
2628
2629
2630
2631
2632
2633
        )
        _, dy_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Implementation with fusible operations
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            if normalization == "LayerNorm":
                norm = te_ops.LayerNorm(
                    hidden_size,
                    eps=layernorm_epsilon,
                    device=device,
                    dtype=dtype,
                )
            else:
                norm = te_ops.RMSNorm(
                    hidden_size,
                    eps=layernorm_epsilon,
                    device=device,
                    dtype=dtype,
                )
            ffn1 = te_ops.Linear(
                hidden_size,
                ffn_hidden_size,
                bias=bias,
                device=device,
                dtype=dtype,
            )
            act = te_ops.GELU()
            ffn2 = te_ops.Linear(
                ffn_hidden_size,
                hidden_size,
                bias=bias,
                device=device,
                dtype=dtype,
            )
        forward = te_ops.Sequential(norm, ffn1, act, ffn2)
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)