test_math_ieee_math.py 7.93 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import pytest


def run_ieee_math_test(mathop_name,
                       mathop_func,
                       rounding_mode="rn",
                       M=128,
                       N=128,
                       block_M=32,
                       block_N=32,
                       dtype="float32"):
    """
    Test IEEE-compliant math operations with specified rounding modes.

    Builds an element-wise TileLang kernel that applies ``mathop_func`` with
    ``rounding_mode`` to random ``(M, N)`` CUDA tensors, compiles it with
    fast-math disabled, executes it, and (when a reference is known for
    ``mathop_name``) compares the output against a float64 PyTorch reference.

    Args:
        mathop_name: Operation name. ``"ieee_fmaf"`` builds a three-input
            kernel; ``"ieee_add"``, ``"ieee_sub"``, ``"ieee_mul"`` and
            ``"ieee_fdiv"`` build two-input kernels; any other name builds a
            single-input kernel.
        mathop_func: TileLang intrinsic (e.g. ``T.ieee_add``), called as
            ``mathop_func(*operands, rounding_mode)`` inside the kernel.
        rounding_mode: Rounding-mode string forwarded to ``mathop_func``.
        M, N: Shape of every input/output tensor.
        block_M, block_N: Tile shape processed by each kernel block.
        dtype: Element type name; must also be an attribute of ``torch``.

    Raises:
        AssertionError: If the kernel returns nothing, returns the wrong shape,
            or differs from the reference beyond tolerance.
        Any exception raised while compiling or executing the kernel is
        propagated so the calling test fails instead of silently passing.
    """

    # Define the appropriate function based on operation type to avoid TVM parsing conflicts
    if mathop_name == "ieee_fmaf":

        @T.prim_func
        def main_func(
                A: T.Tensor((M, N), dtype),
                B: T.Tensor((M, N), dtype),
                C: T.Tensor((M, N), dtype),
                D: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
                for i, j in T.Parallel(block_M, block_N):
                    D[by * block_M + i,
                      bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
                                                      B[by * block_M + i, bx * block_N + j],
                                                      C[by * block_M + i,
                                                        bx * block_N + j], rounding_mode)

        out_idx = [3]
        num_inputs = 3
    elif mathop_name in ["ieee_add", "ieee_sub", "ieee_mul", "ieee_fdiv"]:

        @T.prim_func
        def main_func(
                A: T.Tensor((M, N), dtype),
                B: T.Tensor((M, N), dtype),
                C: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
                for i, j in T.Parallel(block_M, block_N):
                    C[by * block_M + i,
                      bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
                                                      B[by * block_M + i,
                                                        bx * block_N + j], rounding_mode)

        out_idx = [2]
        num_inputs = 2
    else:  # Single argument operations

        @T.prim_func
        def main_func(
                A: T.Tensor((M, N), dtype),
                B: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
                for i, j in T.Parallel(block_M, block_N):
                    B[by * block_M + i,
                      bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
                                                      rounding_mode)

        out_idx = [1]
        num_inputs = 1

    # Test compilation (fast math off so the IEEE intrinsics are not relaxed).
    kernel = tilelang.compile(
        main_func,
        out_idx=out_idx,
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
        })

    print(f"\n=== Testing {mathop_name} with rounding mode {rounding_mode} ===")
    print(f"✓ {mathop_name} compilation test passed")

    # Test numerical execution. Inputs are drawn in kernel-parameter order
    # (A, then B, then C).
    torch_dtype = getattr(torch, dtype)
    inputs = [
        torch.randn(M, N, device="cuda", dtype=torch_dtype) for _ in range(num_inputs)
    ]

    # Keep operands inside each function's domain.
    if mathop_name in ["ieee_frcp", "ieee_fsqrt"]:
        inputs[0] = torch.abs(inputs[0]) + 0.1
    elif mathop_name == "ieee_fdiv":
        inputs[1] = torch.abs(inputs[1]) + 0.1  # Avoid division by zero

    # Execute kernel. Errors are intentionally NOT caught: a kernel that fails
    # to run must fail the test rather than print a warning.
    result = kernel(*inputs)
    assert result is not None, f"{mathop_name} returned no output"
    assert tuple(result.shape) == (M, N), (
        f"{mathop_name} returned shape {tuple(result.shape)}, expected {(M, N)}")

    # Reference semantics, evaluated in float64 so the reference itself adds
    # essentially no rounding error.
    reference_ops = {
        "ieee_add": lambda x, y: x + y,
        "ieee_sub": lambda x, y: x - y,
        "ieee_mul": lambda x, y: x * y,
        "ieee_fdiv": lambda x, y: x / y,
        "ieee_fmaf": lambda x, y, z: x * y + z,
        "ieee_frcp": lambda x: 1.0 / x,
        "ieee_fsqrt": torch.sqrt,
    }
    reference_op = reference_ops.get(mathop_name)
    if reference_op is not None:
        expected = reference_op(*(x.double() for x in inputs))
        # Directed rounding modes (rz/ru/rd) may legitimately differ from the
        # exact result by about one ulp, so compare with a small multiple of
        # the dtype's machine epsilon rather than exact equality.
        tol = max(1e-5, 8 * torch.finfo(torch_dtype).eps)
        torch.testing.assert_close(result.double(), expected, rtol=tol, atol=tol)

    print(f"✓ {mathop_name} numerical execution test passed")


def test_rounding_mode_validation():
    """Check that IEEE intrinsics reject unknown rounding-mode strings.

    Each deferred call below passes an unsupported mode and must raise a
    ``ValueError`` whose message matches "Invalid rounding mode".
    """
    rejected_calls = (
        lambda: T.ieee_add(1.0, 2.0, "invalid_mode"),
        lambda: T.ieee_mul(1.0, 2.0, "xy"),
        lambda: T.ieee_fsqrt(4.0, "bad_mode"),
    )
    for bad_call in rejected_calls:
        with pytest.raises(ValueError, match="Invalid rounding mode"):
            bad_call()

    print("✓ Rounding mode validation test passed")


@tilelang.testing.requires_cuda
def test_ieee_add_all_rounding_modes():
    """Run ``T.ieee_add`` through the shared harness for rn, rz, ru and rd."""
    op_name, op = "ieee_add", T.ieee_add
    for rm in ("rn", "rz", "ru", "rd"):
        run_ieee_math_test(op_name, op, rounding_mode=rm)
        print(f"✓ {op_name} with {rm} passed")


@tilelang.testing.requires_cuda
def test_ieee_sub_all_rounding_modes():
    """Run ``T.ieee_sub`` through the shared harness for rn, rz, ru and rd."""
    op_name, op = "ieee_sub", T.ieee_sub
    for rm in ("rn", "rz", "ru", "rd"):
        run_ieee_math_test(op_name, op, rounding_mode=rm)
        print(f"✓ {op_name} with {rm} passed")


@tilelang.testing.requires_cuda
def test_ieee_mul_all_rounding_modes():
    """Run ``T.ieee_mul`` through the shared harness for rn, rz, ru and rd."""
    op_name, op = "ieee_mul", T.ieee_mul
    for rm in ("rn", "rz", "ru", "rd"):
        run_ieee_math_test(op_name, op, rounding_mode=rm)
        print(f"✓ {op_name} with {rm} passed")


@tilelang.testing.requires_cuda
def test_ieee_fmaf_all_rounding_modes():
    """Run ``T.ieee_fmaf`` through the shared harness for rn, rz, ru and rd."""
    op_name, op = "ieee_fmaf", T.ieee_fmaf
    for rm in ("rn", "rz", "ru", "rd"):
        run_ieee_math_test(op_name, op, rounding_mode=rm)
        print(f"✓ {op_name} with {rm} passed")


@tilelang.testing.requires_cuda
def test_ieee_frcp_all_rounding_modes():
    """Run ``T.ieee_frcp`` through the shared harness for rn, rz, ru and rd."""
    op_name, op = "ieee_frcp", T.ieee_frcp
    for rm in ("rn", "rz", "ru", "rd"):
        run_ieee_math_test(op_name, op, rounding_mode=rm)
        print(f"✓ {op_name} with {rm} passed")


@tilelang.testing.requires_cuda
def test_ieee_fsqrt_all_rounding_modes():
    """Run ``T.ieee_fsqrt`` through the shared harness for rn, rz, ru and rd."""
    op_name, op = "ieee_fsqrt", T.ieee_fsqrt
    for rm in ("rn", "rz", "ru", "rd"):
        run_ieee_math_test(op_name, op, rounding_mode=rm)
        print(f"✓ {op_name} with {rm} passed")


@tilelang.testing.requires_cuda
def test_ieee_frsqrt_rn_only():
    """Test IEEE reciprocal square root (round to nearest only).

    ``T.ieee_frsqrt`` is called without a rounding-mode argument, so this test
    builds its own 128x128 float32 kernel instead of using the shared
    harness. It compiles the kernel with fast math disabled, runs it on
    strictly positive inputs, and compares the result with ``torch.rsqrt``
    evaluated in float64. Compilation or execution errors propagate so the
    test fails instead of printing a warning.
    """

    @T.prim_func
    def main(
            A: T.Tensor((128, 128), "float32"),
            B: T.Tensor((128, 128), "float32"),
    ):
        with T.Kernel(T.ceildiv(128, 32), T.ceildiv(128, 32), threads=128) as (bx, by):
            for i, j in T.Parallel(32, 32):
                B[by * 32 + i, bx * 32 + j] = T.ieee_frsqrt(A[by * 32 + i, bx * 32 + j])

    kernel = tilelang.compile(
        main,
        out_idx=[1],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
        })

    print("\n=== Testing ieee_frsqrt (rn only) ===")
    print("✓ ieee_frsqrt compilation test passed")

    # Test numerical execution; +0.1 keeps inputs strictly positive.
    a = torch.abs(torch.randn(128, 128, device="cuda", dtype=torch.float32)) + 0.1

    # No try/except: a failing kernel must fail the test.
    result = kernel(a)
    assert result is not None, "ieee_frsqrt returned no output"
    assert tuple(result.shape) == (128, 128), (
        f"ieee_frsqrt returned shape {tuple(result.shape)}, expected (128, 128)")

    # float64 reference; tolerance is far looser than one float32 ulp.
    expected = torch.rsqrt(a.double())
    torch.testing.assert_close(result.double(), expected, rtol=1e-5, atol=1e-5)
    print("✓ ieee_frsqrt numerical execution test passed")


@tilelang.testing.requires_cuda
def test_ieee_fdiv_all_rounding_modes():
    """Run ``T.ieee_fdiv`` through the shared harness for rn, rz, ru and rd."""
    op_name, op = "ieee_fdiv", T.ieee_fdiv
    for rm in ("rn", "rz", "ru", "rd"):
        run_ieee_math_test(op_name, op, rounding_mode=rm)
        print(f"✓ {op_name} with {rm} passed")


if __name__ == "__main__":
    # Allow running this file directly as a script; hands off to
    # tilelang.testing.main() (presumably collects and runs the test_* functions above).
    tilelang.testing.main()