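"""Codegen tests for tilelang math ops and CUDA fastmath intrinsics.

Plain ops such as T.exp are expected to lower to non-fastmath library calls
(e.g. expf), while the double-underscore variants such as T.__exp are expected
to lower to CUDA fastmath intrinsics (e.g. __expf).
"""
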
import pytest
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import re


def get_mathop_lines(source, mathop_name):
    """Extract lines containing the mathop from CUDA source for debugging"""
    lines = source.split("\n")
    relevant_lines = []
    for i, line in enumerate(lines):
        if mathop_name in line and ("(" in line):
            # Include some context
            start = max(0, i - 1)
            end = min(len(lines), i + 2)
            relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)])
            relevant_lines.append("---")
    return "\n".join(relevant_lines[-10:])  # Show last 10 lines to avoid too much output
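
# Illustrative output format of get_mathop_lines (hypothetical snippet, not from a real kernel):
#   41:   B[i] = expf(A[i]);
#   42: }
#   ---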


def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
    """Check source for fastmath/non-fastmath versions"""
    fastmath_pattern = rf"__({mathop_name}f?)\b"
    non_fastmath_pattern = rf"(?<!__)({mathop_name}f?)\b"
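    # For mathop_name == "exp", the fastmath pattern matches intrinsics such as
    # "__expf", while the non-fastmath pattern matches plain calls such as
    # "expf" or "exp" (the lookbehind rejects occurrences prefixed by "__").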

    fastmath_matches = re.findall(fastmath_pattern, source)
    non_fastmath_matches = re.findall(non_fastmath_pattern, source)

    print(f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls")
    if len(fastmath_matches) > 0:
        print(f"Fastmath calls found: {fastmath_matches}")
    if len(non_fastmath_matches) > 0:
        print(f"Non-fastmath calls found: {non_fastmath_matches}")
    print(f"Source preview for {mathop_name}:")
    print(get_mathop_lines(source, mathop_name))

    if expect_fastmath:
        assert len(fastmath_matches) > 0, "Expected fastmath calls but found none"
        print(f"✓ {mathop_name} correctly uses fastmath versions")
    else:
        assert len(fastmath_matches) == 0, f"Found unexpected fastmath calls: {fastmath_matches}"
        assert len(non_fastmath_matches) > 0, f"No {mathop_name} calls found"
        print(f"✓ {mathop_name} correctly uses non-fastmath versions")


def check_non_fastmath_usage(source, mathop_name):
    """Check that source uses non-fastmath versions (no __ prefix)"""
    check_fastmath_usage(source, mathop_name, expect_fastmath=False)
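
# Example usage of the helpers above on hypothetical CUDA snippets:
#   check_non_fastmath_usage("B[i] = expf(A[i]);", "exp")                      # passes
#   check_fastmath_usage("B[i] = __expf(A[i]);", "exp", expect_fastmath=True)  # passes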


def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
    """
    Test single-argument mathops.
    T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
    """

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
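        # Grid of ceildiv(N, block_N) x ceildiv(M, block_M) blocks; each block applies
        # the op elementwise to one block_M x block_N tile of A and writes it to B.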
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])

    # Test with FAST_MATH disabled
    kernel_no_fastmath = tilelang.compile(
        main,
        out_idx=[1],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
        },
    )

    source_no_fastmath = kernel_no_fastmath.get_kernel_source()

    print(f"\n=== Testing {mathop_name} ===")
    print("FAST_MATH=False:")

    # With FAST_MATH disabled, the tl.* intrinsics should generate non-fastmath versions (e.g., expf)
    check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)

    print(f"✓ {mathop_name} compilation and execution test passed")


def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
    """
    Test two-argument mathops to ensure they generate non-fastmath CUDA code.
    """

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = mathop_func(
                    A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j]
                )

    # Test with FAST_MATH disabled
    kernel_no_fastmath = tilelang.compile(
        main,
        out_idx=[2],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
        },
    )

    # Test with FAST_MATH enabled
    kernel_fastmath = tilelang.compile(
        main,
        out_idx=[2],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        },
    )

    source_no_fastmath = kernel_no_fastmath.get_kernel_source()
    source_fastmath = kernel_fastmath.get_kernel_source()
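
    # Both sources are checked the same way: two-argument ops are expected to lower
    # to plain library calls (e.g., powf / fmodf) regardless of the FAST_MATH setting.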

    print(f"\n=== Testing {mathop_name} (two args) ===")
    print("FAST_MATH=False:")
    check_non_fastmath_usage(source_no_fastmath, mathop_name)

    print("FAST_MATH=True:")
    check_non_fastmath_usage(source_fastmath, mathop_name)

    # Test numerical correctness
    torch_dtype = dtype.as_torch()
    a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
    b = torch.randn(M, N, device="cuda", dtype=torch_dtype)

    # Ensure positive values for functions that need them
    if mathop_name == "pow":
        a = torch.abs(a) + 0.1
        b = torch.clamp(b, -3, 3)  # Limit exponent range
    elif mathop_name == "fmod":
        b = torch.abs(b) + 0.1  # Avoid division by zero

    c_no_fastmath = kernel_no_fastmath(a, b)
    c_fastmath = kernel_fastmath(a, b)

    # Both should produce similar results
    torch.testing.assert_close(c_no_fastmath, c_fastmath, rtol=1e-3, atol=1e-3)
    print(f"✓ {mathop_name} numerical test passed")


def run_abs_test():
    """Test that abs correctly maps to fabs (not __fabsf) in generated CUDA code"""
    M, N = 128, 128
    block_M, block_N = 32, 32

    @T.prim_func
    def main(
        A: T.Tensor((M, N), T.float32),
        B: T.Tensor((M, N), T.float32),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = T.abs(A[by * block_M + i, bx * block_N + j])

    kernel = tilelang.compile(
        main,
        out_idx=[1],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
        },
    )

    source = kernel.get_kernel_source()
    print("\n=== Testing abs (maps to fabs) ===")
    check_non_fastmath_usage(source, "fabs")

    # Test numerical correctness
    a = torch.randn(M, N, device="cuda", dtype=torch.float32)
    b = kernel(a)
    expected = torch.abs(a)

    torch.testing.assert_close(b, expected, rtol=1e-5, atol=1e-5)
    print("✓ abs numerical test passed")


def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
    """
    Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
    """

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])

    # Test with FAST_MATH enabled
    kernel_fastmath = tilelang.compile(
        main,
        out_idx=[1],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        },
    )

    source_fastmath = kernel_fastmath.get_kernel_source()

    print(f"\n=== Testing {mathop_name} (fastmath version) ===")
    print("FAST_MATH=True:")
    # Strip the __ prefix for checking in the CUDA source
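    # e.g. mathop_name "__exp" -> cuda_mathop_name "exp"; the generated source is then
    # expected to contain the fastmath intrinsic "__expf".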
    cuda_mathop_name = mathop_name.lstrip("_")
    check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True)

    # Test numerical correctness
    torch_dtype = dtype.as_torch()
    a = torch.randn(M, N, device="cuda", dtype=torch_dtype)

    # Ensure positive values for functions that need them
    if cuda_mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]:
        a = torch.abs(a) + 0.1

    b_fastmath = kernel_fastmath(a)

    # Compare with reference implementation
    if cuda_mathop_name == "exp":
        expected = torch.exp(a)
    elif cuda_mathop_name == "log":
        expected = torch.log(a)
    else:
        expected = b_fastmath  # Just check compilation works

    torch.testing.assert_close(b_fastmath, expected, rtol=1e-3, atol=1e-3)
    print(f"✓ {mathop_name} numerical test passed")


@pytest.mark.parametrize(
    "name, func",
    [
        ("exp", T.exp),
        ("exp2", T.exp2),
        ("exp10", T.exp10),
        ("log", T.log),
        ("log2", T.log2),
        ("log10", T.log10),
        ("sin", T.sin),
        ("cos", T.cos),
        ("tan", T.tan),
        ("sinh", T.sinh),
        ("cosh", T.cosh),
        ("tanh", T.tanh),
        ("atan", T.atan),
        ("sqrt", T.sqrt),
        ("rsqrt", T.rsqrt),
        ("erf", T.erf),
        ("floor", T.floor),
        ("ceil", T.ceil),
        ("trunc", T.trunc),
        ("round", T.round),
        ("nearbyint", T.nearbyint),
    ],
)
@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath(name, func):
    """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
    run_single_arg_mathop_test(name, func, dtype=T.float32)
    print(f"✓ {name} test passed")


@pytest.mark.parametrize(
    "name, func",
    [
        ("pow", T.pow),
        ("fmod", T.fmod),
    ],
)
@tilelang.testing.requires_cuda
def test_two_arg_mathops_fastmath(name, func):
    """Test all two-argument mathops"""
    run_two_arg_mathop_test(name, func, dtype=T.float32)


@tilelang.testing.requires_cuda
def test_abs_maps_to_fabs():
    """Test that abs correctly maps to fabs"""
    run_abs_test()


@pytest.mark.parametrize(
    "name, func",
    [
        ("__exp", T.__exp),
        ("__exp10", T.__exp10),
        ("__log", T.__log),
        ("__log2", T.__log2),
        ("__log10", T.__log10),
        ("__tan", T.__tan),
        ("__cos", T.__cos),
        ("__sin", T.__sin),
    ],
)
@tilelang.testing.requires_cuda
def test_fastmath_versions(name, func):
    """Test that __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin generate fastmath CUDA code"""
    run_fastmath_mathop_test(name, func, dtype=T.float32)
    print(f"✓ {name} test passed")


if __name__ == "__main__":
    tilelang.testing.main()