# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Tuple

import pytest
import torch

import transformer_engine.pytorch as te

# Table of models exercised by test_torch_dynamo: each entry maps a model
# name to a zero-argument constructor and the input tensor dims to feed it.
_model_factory = {
    name: [builder, input_dims]
    for name, builder, input_dims in (
        ("Linear", lambda: te.Linear(16, 16), [16, 16]),
        ("LayerNorm", lambda: te.LayerNorm(16), [16, 16]),
        ("LayerNormLinear", lambda: te.LayerNormLinear(16, 16), [16, 16]),
        ("LayerNormMLP", lambda: te.LayerNormMLP(16, 16), [16, 16]),
        ("TransformerLayer", lambda: te.TransformerLayer(128, 128, 2), [4, 1, 128]),
    )
}


@pytest.mark.skipif(torch.__version__ < "2", reason="torch.compile not available")
@pytest.mark.parametrize("model_name", list(_model_factory.keys()))
def test_torch_dynamo(model_name: str):
    """Test compatibility with Torch Dynamo

    Construct model, optimize with Torch Dynamo, and perform a single
    forward and backward pass.

    """

    # Helper function to construct tensor with default options
    def make_tensor(
        dims: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        requires_grad: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        return torch.zeros(
            dims,
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
            **kwargs,
        )

    # Construct model and input tensors
    model_builder, input_builder = _model_factory[model_name]
    model = model_builder()
    inputs = [make_tensor(input_builder)]

    # Optimize model with TorchDynamo. torch.compile returns the compiled
    # module without modifying its argument, so the result must be
    # captured — otherwise the eager model would be tested instead.
    model = torch.compile(model)

    # Forward and backward pass
    out = model(*inputs)
    out.backward(torch.zeros_like(out))


def test_lazy_compile():
    """Smoke test to ensure lazy compilation is working."""
    from transformer_engine.pytorch.jit import dgelu_fused_

    grad_output = torch.randn(10, 10)
    inp = torch.randn(10, 10)
    dgelu_fused_(grad_output, inp)


def test_l2normalization_fused():
    """Smoke test for L2Normalization fusion functions."""
    from transformer_engine.pytorch.jit import (
        l2normalization_fused,
        l2normalization_fwd_fused,
        l2normalization_backward_fused,
    )

    # Basic smoke test like other JIT functions
    def make_input(requires_grad: bool = False) -> torch.Tensor:
        return torch.randn(
            10, 128, device="cuda", dtype=torch.float32, requires_grad=requires_grad
        )

    eps = 1e-6

    # Inference entry point
    l2normalization_fused(make_input(), eps)

    # Training entry points: fused forward followed by fused backward
    x = make_input(requires_grad=True)
    y, rsqrt_norm = l2normalization_fwd_fused(x, eps)
    dy = torch.randn_like(y)
    l2normalization_backward_fused(dy, x, rsqrt_norm, eps)


def test_l2normalization_fused_correctness():
    """Simple verification that L2Normalization fusion matches reference implementation."""
    from transformer_engine.pytorch.jit import (
        l2normalization_fwd_fused,
        l2normalization_backward_fused,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    eps = 1e-6
    x = torch.randn(16, 64, device=device, dtype=torch.float32, requires_grad=True)

    # Fused forward
    y_fused, rsqrt_norm = l2normalization_fwd_fused(x, eps)

    # Reference forward: y = x * rsqrt(sum(x^2, dim=-1) + eps)
    x_ref = x.clone().detach().requires_grad_(True)
    rsqrt_norm_ref = torch.rsqrt(x_ref.pow(2).sum(dim=-1, keepdim=True) + eps)
    y_ref = x_ref * rsqrt_norm_ref

    # Forward outputs must agree with the reference
    torch.testing.assert_close(y_fused, y_ref, atol=1e-6, rtol=1e-5)
    torch.testing.assert_close(rsqrt_norm, rsqrt_norm_ref, atol=1e-6, rtol=1e-5)

    # Fused backward vs autograd through the reference forward
    dy = torch.randn_like(y_fused)
    dx_fused = l2normalization_backward_fused(dy, x, rsqrt_norm, eps)
    y_ref.backward(dy)
    torch.testing.assert_close(dx_fused, x_ref.grad, atol=1e-5, rtol=1e-4)