"docs/source/vscode:/vscode.git/clone" did not exist on "ea1b4ea7ca5fa7ab3d2094a70a35481f2a1036c2"
Unverified Commit 95c373f5 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[FastMath] Disable default TVM fastmath intrinsic dispatch and add explicit...

[FastMath] Disable default TVM fastmath intrinsic dispatch and add explicit fastmath op to invoke (#875)

* Add fast math operations for CUDA: exp, exp10, log, log2, log10, tan, cos, and sin (#865)

* Refactor fast math operation definitions for consistency and readability in CUDA code. Consolidated multiple definitions into single lines and improved formatting in related test files for better clarity.

* Remove unnecessary pass configurations for warp specialization and TMA lowering in fast math operation tests for CUDA. This simplifies the test setup while maintaining the focus on fast math functionality.

* Update fastmath tests to reflect that tl.* intrinsics generate no fastmath versions and disable cache in main execution.

* Fix formatting in fastmath test comments for clarity on tl.* intrinsics behavior.

* Add precision comparison tool for CUDA operations

This commit introduces a new Python script and CUDA source file for a precision comparison tool that evaluates the accuracy of various CUDA operations (including division, reciprocal, exponential, logarithmic, and trigonometric functions) across different implementations: CUDA Precise, CUDA Fast, Triton, Triton LibDevice, and TileLang. The tool generates test data, executes the operations, and summarizes the error statistics for each implementation against a double precision reference. Additionally, a README file is added to document the results of the comparisons for various operations.

* Add precision comparison tool for CUDA operations

This commit introduces a new precision comparison tool implemented in Python and CUDA, designed to evaluate the accuracy of various mathematical operations (division, reciprocal, exponential, logarithmic, trigonometric, square root, etc.) across different frameworks including CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang. The tool includes functionality for generating test data, executing operations, and summarizing error statistics for each implementation. Additionally, it provides a comprehensive README with error metrics for each operation tested.
parent 56f7494f
Subproject commit 7a71ee3411e49c3e05b1f1a910cf7f73adc7a5b2 Subproject commit 883e96b42ae0df40c2f7194cc932bbcd9d0c5627
=== div ===
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
------------------------------------------------------------------------------------------
FP32 Precise vs Double max abs: 1.219e-04, mean abs: 8.916e-08, max rel: 5.952e-08, mean rel: 2.152e-08
Triton LibDevice vs Double max abs: 1.219e-04, mean abs: 8.916e-08, max rel: 5.952e-08, mean rel: 2.152e-08
TileLang vs Double max abs: 1.219e-04, mean abs: 8.916e-08, max rel: 5.952e-08, mean rel: 2.152e-08
PyTorch vs Double max abs: 1.219e-04, mean abs: 8.916e-08, max rel: 5.952e-08, mean rel: 2.152e-08
Triton vs Double max abs: 2.605e-04, mean abs: 1.285e-07, max rel: 1.455e-07, mean rel: 3.175e-08
TileLang Fastmath vs Double max abs: 2.605e-04, mean abs: 1.285e-07, max rel: 1.455e-07, mean rel: 3.175e-08
CUDA Fast vs Double max abs: 2.605e-04, mean abs: 1.285e-07, max rel: 1.455e-07, mean rel: 3.175e-08
=== reciprocal ===
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
------------------------------------------------------------------------------------------
FP32 Precise vs Double max abs: 3.039e-05, mean abs: 4.418e-08, max rel: 5.960e-08, mean rel: 2.235e-08
Triton LibDevice vs Double max abs: 3.039e-05, mean abs: 4.418e-08, max rel: 5.960e-08, mean rel: 2.235e-08
TileLang vs Double max abs: 3.039e-05, mean abs: 4.418e-08, max rel: 5.960e-08, mean rel: 2.235e-08
PyTorch vs Double max abs: 3.039e-05, mean abs: 4.418e-08, max rel: 5.960e-08, mean rel: 2.235e-08
Triton vs Double max abs: 4.470e-05, mean abs: 4.886e-08, max rel: 9.699e-08, mean rel: 2.461e-08
TileLang Fastmath vs Double max abs: 4.470e-05, mean abs: 4.886e-08, max rel: 9.699e-08, mean rel: 2.461e-08
CUDA Fast vs Double max abs: 4.470e-05, mean abs: 4.886e-08, max rel: 9.699e-08, mean rel: 2.461e-08
=== exp ===
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
------------------------------------------------------------------------------------------
FP32 Precise vs Double max abs: 5.494e-06, mean abs: 2.153e-07, max rel: 1.483e-07, mean rel: 3.200e-08
Triton LibDevice vs Double max abs: 5.494e-06, mean abs: 2.153e-07, max rel: 1.483e-07, mean rel: 3.200e-08
TileLang vs Double max abs: 5.494e-06, mean abs: 2.153e-07, max rel: 1.483e-07, mean rel: 3.200e-08
PyTorch vs Double max abs: 5.494e-06, mean abs: 2.153e-07, max rel: 1.483e-07, mean rel: 3.200e-08
Triton vs Double max abs: 1.338e-05, mean abs: 5.023e-07, max rel: 2.641e-07, mean rel: 5.564e-08
TileLang Fastmath vs Double max abs: 1.338e-05, mean abs: 5.023e-07, max rel: 2.641e-07, mean rel: 5.564e-08
CUDA Fast vs Double max abs: 1.338e-05, mean abs: 5.023e-07, max rel: 2.641e-07, mean rel: 5.564e-08
=== log ===
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
------------------------------------------------------------------------------------------
FP32 Precise vs Double max abs: 2.684e-07, mean abs: 2.051e-08, max rel: 7.886e-08, mean rel: 2.297e-08
Triton LibDevice vs Double max abs: 2.684e-07, mean abs: 2.051e-08, max rel: 7.886e-08, mean rel: 2.297e-08
TileLang vs Double max abs: 2.684e-07, mean abs: 2.051e-08, max rel: 7.886e-08, mean rel: 2.297e-08
PyTorch vs Double max abs: 2.684e-07, mean abs: 2.051e-08, max rel: 7.886e-08, mean rel: 2.297e-08
Triton vs Double max abs: 2.684e-07, mean abs: 2.051e-08, max rel: 7.886e-08, mean rel: 2.297e-08
TileLang Fastmath vs Double max abs: 9.087e-07, mean abs: 4.760e-08, max rel: 2.019e-02, mean rel: 3.183e-07
CUDA Fast vs Double max abs: 9.087e-07, mean abs: 4.760e-08, max rel: 2.019e-02, mean rel: 3.183e-07
=== sin ===
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
------------------------------------------------------------------------------------------
FP32 Precise vs Double max abs: 7.731e-08, mean abs: 1.401e-08, max rel: 1.148e-07, mean rel: 2.492e-08
Triton LibDevice vs Double max abs: 7.731e-08, mean abs: 1.401e-08, max rel: 1.148e-07, mean rel: 2.492e-08
TileLang vs Double max abs: 7.731e-08, mean abs: 1.401e-08, max rel: 1.148e-07, mean rel: 2.492e-08
PyTorch vs Double max abs: 7.731e-08, mean abs: 1.401e-08, max rel: 1.148e-07, mean rel: 2.492e-08
Triton vs Double max abs: 7.731e-08, mean abs: 1.401e-08, max rel: 1.148e-07, mean rel: 2.492e-08
TileLang Fastmath vs Double max abs: 6.463e-07, mean abs: 1.251e-07, max rel: 7.111e-02, mean rel: 1.425e-06
CUDA Fast vs Double max abs: 6.463e-07, mean abs: 1.251e-07, max rel: 7.111e-02, mean rel: 1.425e-06
=== cos ===
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
------------------------------------------------------------------------------------------
FP32 Precise vs Double max abs: 8.668e-08, mean abs: 1.587e-08, max rel: 1.199e-07, mean rel: 2.513e-08
Triton LibDevice vs Double max abs: 8.668e-08, mean abs: 1.587e-08, max rel: 1.199e-07, mean rel: 2.513e-08
TileLang vs Double max abs: 8.668e-08, mean abs: 1.587e-08, max rel: 1.199e-07, mean rel: 2.513e-08
PyTorch vs Double max abs: 8.668e-08, mean abs: 1.587e-08, max rel: 1.199e-07, mean rel: 2.513e-08
Triton vs Double max abs: 8.668e-08, mean abs: 1.587e-08, max rel: 1.199e-07, mean rel: 2.513e-08
TileLang Fastmath vs Double max abs: 4.006e-07, mean abs: 9.249e-08, max rel: 5.275e-02, mean rel: 7.307e-07
CUDA Fast vs Double max abs: 4.006e-07, mean abs: 9.249e-08, max rel: 5.275e-02, mean rel: 7.307e-07
=== sqrt ===
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
------------------------------------------------------------------------------------------
FP32 Precise vs Double max abs: 5.960e-08, mean abs: 2.554e-08, max rel: 5.960e-08, mean rel: 1.986e-08
Triton LibDevice vs Double max abs: 5.960e-08, mean abs: 2.554e-08, max rel: 5.960e-08, mean rel: 1.986e-08
TileLang vs Double max abs: 5.960e-08, mean abs: 2.554e-08, max rel: 5.960e-08, mean rel: 1.986e-08
PyTorch vs Double max abs: 5.960e-08, mean abs: 2.554e-08, max rel: 5.960e-08, mean rel: 1.986e-08
Triton vs Double max abs: 1.114e-07, mean abs: 2.947e-08, max rel: 9.962e-08, mean rel: 2.291e-08
TileLang Fastmath vs Double max abs: 1.114e-07, mean abs: 2.947e-08, max rel: 9.962e-08, mean rel: 2.291e-08
CUDA Fast vs Double max abs: 1.114e-07, mean abs: 2.947e-08, max rel: 9.962e-08, mean rel: 2.291e-08
=== tanh ===
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
------------------------------------------------------------------------------------------
FP32 Precise vs Double max abs: 1.056e-07, mean abs: 1.636e-08, max rel: 1.966e-07, mean rel: 2.359e-08
Triton LibDevice vs Double max abs: 1.056e-07, mean abs: 1.636e-08, max rel: 1.966e-07, mean rel: 2.359e-08
TileLang vs Double max abs: 1.056e-07, mean abs: 1.636e-08, max rel: 1.966e-07, mean rel: 2.359e-08
PyTorch vs Double max abs: 1.056e-07, mean abs: 1.636e-08, max rel: 1.966e-07, mean rel: 2.359e-08
Triton vs Double max abs: 2.293e-07, mean abs: 3.965e-08, max rel: 6.204e-04, mean rel: 1.100e-07
TileLang Fastmath vs Double max abs: 7.826e-06, mean abs: 1.384e-06, max rel: 1.081e-05, mean rel: 1.906e-06
CUDA Fast vs Double max abs: 7.826e-06, mean abs: 1.384e-06, max rel: 1.081e-05, mean rel: 1.906e-06
=== rsqrt ===
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
------------------------------------------------------------------------------------------
FP32 Precise vs Double max abs: 2.057e-06, mean abs: 2.798e-08, max rel: 1.224e-07, mean rel: 2.918e-08
Triton LibDevice vs Double max abs: 9.535e-07, mean abs: 2.199e-08, max rel: 5.960e-08, mean rel: 2.315e-08
TileLang vs Double max abs: 2.057e-06, mean abs: 2.798e-08, max rel: 1.224e-07, mean rel: 2.918e-08
PyTorch vs Double max abs: 2.057e-06, mean abs: 2.798e-08, max rel: 1.224e-07, mean rel: 2.918e-08
Triton vs Double max abs: 2.057e-06, mean abs: 2.798e-08, max rel: 1.224e-07, mean rel: 2.918e-08
TileLang Fastmath vs Double max abs: 2.057e-06, mean abs: 2.798e-08, max rel: 1.224e-07, mean rel: 2.918e-08
CUDA Fast vs Double max abs: 2.057e-06, mean abs: 2.798e-08, max rel: 1.224e-07, mean rel: 2.918e-08
=== inv_sqrt ===
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
------------------------------------------------------------------------------------------
FP32 Precise vs Double max abs: 2.501e-06, mean abs: 2.911e-08, max rel: 8.939e-08, mean rel: 2.963e-08
Triton LibDevice vs Double max abs: 2.501e-06, mean abs: 2.911e-08, max rel: 8.939e-08, mean rel: 2.963e-08
TileLang vs Double max abs: 2.501e-06, mean abs: 2.911e-08, max rel: 8.939e-08, mean rel: 2.963e-08
PyTorch vs Double max abs: 2.501e-06, mean abs: 2.911e-08, max rel: 8.939e-08, mean rel: 2.963e-08
Triton vs Double max abs: 2.876e-06, mean abs: 3.443e-08, max rel: 1.536e-07, mean rel: 3.503e-08
TileLang Fastmath vs Double max abs: 2.876e-06, mean abs: 3.443e-08, max rel: 1.536e-07, mean rel: 3.503e-08
CUDA Fast vs Double max abs: 2.876e-06, mean abs: 3.171e-08, max rel: 1.250e-07, mean rel: 3.211e-08
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ruff: noqa
"""
Precision comparison tool for CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang operations.
"""
import os
import argparse
import sys
from typing import Dict, Optional, Tuple
import torch
from torch.utils.cpp_extension import load
import triton
import triton.language as tl
from triton.language.extra import libdevice
import tilelang
import tilelang.language as T
tilelang.disable_cache()
from tilelang.contrib import nvcc
from tilelang.utils.target import determine_target
# GPU configuration setup
target = determine_target(return_object=True)
compute_version = nvcc.get_target_compute_version(target)
major, minor = nvcc.parse_compute_version(compute_version)
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"
# Operator enumeration - must match OperatorType in C++
OP_NAMES: Dict[int, str] = {
0: "div",
1: "reciprocal",
2: "exp",
3: "log",
4: "sin",
5: "cos",
6: "sqrt",
7: "tanh",
8: "rsqrt",
9: "inv_sqrt"
}
# Block sizes for kernels
TRITON_BLOCK_SIZE = 1024
TILELANG_BLOCK_M = 32
TILELANG_BLOCK_N = 32
TILELANG_THREADS = 128
def parse_arguments() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Precision comparison tool for various CUDA implementations")
parser.add_argument("--n", type=int, default=1000000, help="Number of elements to test")
parser.add_argument("--low", type=float, default=-4.0, help="Lower bound for random values")
parser.add_argument("--high", type=float, default=4.0, help="Upper bound for random values")
parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
return parser.parse_args()
def initialize_cuda() -> torch.nn.Module:
"""Initialize CUDA and load the custom operators module."""
if not torch.cuda.is_available():
print("CUDA is required", file=sys.stderr)
sys.exit(1)
return load(
name="cuda_ops",
sources=["cuda_ops.cu"],
extra_cuda_cflags=[] # No fast_math flags
)
# Initialize global variables
args = parse_arguments()
torch.manual_seed(args.seed)
mod = initialize_cuda()
device = torch.device("cuda")
n = args.n
low, high = args.low, args.high
# Triton kernels
@triton.jit
def triton_binary_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
"""Standard Triton kernel for binary operations (div)."""
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
result = x / y # Division operation
tl.store(out_ptr + offsets, result, mask=mask)
@triton.jit
def triton_libdevice_binary_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
"""LibDevice Triton kernel for binary operations (div)."""
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
result = libdevice.div_rn(x, y) # Round to nearest
tl.store(out_ptr + offsets, result, mask=mask)
@triton.jit
def tl_tanh(x):
"""Triton tanh implementation using sigmoid."""
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def triton_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, BLOCK_SIZE: tl.constexpr):
"""Standard Triton kernel for unary operations."""
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
if op_id == 1: # reciprocal
result = 1.0 / x
elif op_id == 2: # exp
result = tl.exp(x)
elif op_id == 3: # log
result = tl.log(x)
elif op_id == 4: # sin
result = tl.sin(x)
elif op_id == 5: # cos
result = tl.cos(x)
elif op_id == 6: # sqrt
result = tl.sqrt(x)
elif op_id == 7: # tanh
result = tl_tanh(x)
elif op_id == 8: # rsqrt
result = tl.rsqrt(x)
elif op_id == 9: # inv_sqrt
result = 1.0 / tl.sqrt(x)
else:
result = x # Default case
tl.store(out_ptr + offsets, result, mask=mask)
@triton.jit
def triton_libdevice_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr,
BLOCK_SIZE: tl.constexpr):
"""LibDevice Triton kernel for unary operations."""
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
if op_id == 1: # reciprocal
result = libdevice.rcp_rn(x)
elif op_id == 2: # exp
result = libdevice.exp(x)
elif op_id == 3: # log
result = libdevice.log(x)
elif op_id == 4: # sin
result = libdevice.sin(x)
elif op_id == 5: # cos
result = libdevice.cos(x)
elif op_id == 6: # sqrt
result = libdevice.sqrt_rn(x) # Round to nearest
elif op_id == 7: # tanh
result = libdevice.tanh(x)
elif op_id == 8: # rsqrt
result = libdevice.rsqrt_rn(x)
elif op_id == 9: # inv_sqrt
result = libdevice.rcp_rn(libdevice.sqrt_rn(x))
else:
result = x # Default case
tl.store(out_ptr + offsets, result, mask=mask)
# TileLang kernel generators
def make_tilelang_unary_kernel(M: int, N: int, op_id: int, use_fastmath: bool = False):
"""Generate TileLang unary operation kernel."""
@T.prim_func
def tilelang_unary_kernel(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
):
with T.Kernel(
T.ceildiv(N, TILELANG_BLOCK_N),
T.ceildiv(M, TILELANG_BLOCK_M),
threads=TILELANG_THREADS) as (bx, by):
for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
row = by * TILELANG_BLOCK_M + i
col = bx * TILELANG_BLOCK_N + j
x = A[row, col]
if op_id == 1: # reciprocal
B[row, col] = 1.0 / x
elif op_id == 2: # exp
B[row, col] = T.exp(x)
elif op_id == 3: # log
B[row, col] = T.log(x)
elif op_id == 4: # sin
B[row, col] = T.sin(x)
elif op_id == 5: # cos
B[row, col] = T.cos(x)
elif op_id == 6: # sqrt
B[row, col] = T.sqrt(x)
elif op_id == 7: # tanh
B[row, col] = T.tanh(x)
elif op_id == 8: # rsqrt
B[row, col] = T.rsqrt(x)
elif op_id == 9: # inv_sqrt
B[row, col] = 1.0 / T.sqrt(x)
else:
B[row, col] = x # Default case
return tilelang_unary_kernel
def make_tilelang_binary_kernel(M: int, N: int):
"""Generate TileLang binary operation kernel (division)."""
@T.prim_func
def tilelang_binary_kernel(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
C: T.Tensor((M, N), "float32"),
):
with T.Kernel(
T.ceildiv(N, TILELANG_BLOCK_N),
T.ceildiv(M, TILELANG_BLOCK_M),
threads=TILELANG_THREADS) as (bx, by):
for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
row = by * TILELANG_BLOCK_M + i
col = bx * TILELANG_BLOCK_N + j
x = A[row, col]
y = B[row, col]
C[row, col] = x / y # Division operation
return tilelang_binary_kernel
def tilelang_op(x: torch.Tensor,
op_id: int,
y: Optional[torch.Tensor] = None,
use_fastmath: bool = False) -> torch.Tensor:
"""TileLang operation interface."""
assert x.is_cuda
# Reshape 1D tensor to 2D for TileLang kernels
original_shape = x.shape
if len(x.shape) == 1:
x = x.view(1, -1)
if y is not None:
y = y.view(1, -1)
M, N = x.shape
if op_id == 0: # Division - binary operation
assert y is not None, "Division operation requires second operand"
kernel_func = make_tilelang_binary_kernel(M, N)
kernel = tilelang.compile(
kernel_func,
out_idx=[2],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath,
})
out = kernel(x, y)
else: # Unary operation
kernel_func = make_tilelang_unary_kernel(M, N, op_id, use_fastmath)
kernel = tilelang.compile(
kernel_func,
out_idx=[1],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath,
})
out = kernel(x)
# Restore original shape
return out.view(original_shape)
def triton_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Standard Triton operation interface."""
assert x.is_cuda
out = torch.empty_like(x)
grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
if op_id == 0: # Division - binary operation
assert y is not None, "Division operation requires second operand"
triton_binary_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=TRITON_BLOCK_SIZE)
else: # Unary operation
triton_unary_kernel[grid](x, out, x.numel(), op_id, BLOCK_SIZE=TRITON_BLOCK_SIZE)
return out
def triton_libdevice_op(x: torch.Tensor,
op_id: int,
y: Optional[torch.Tensor] = None) -> torch.Tensor:
"""LibDevice Triton operation interface."""
assert x.is_cuda
out = torch.empty_like(x)
grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
if op_id == 0: # Division - binary operation
assert y is not None, "Division operation requires second operand"
triton_libdevice_binary_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=TRITON_BLOCK_SIZE)
else: # Unary operation
triton_libdevice_unary_kernel[grid](x, out, x.numel(), op_id, BLOCK_SIZE=TRITON_BLOCK_SIZE)
return out
def get_pytorch_reference(x: torch.Tensor,
op_id: int,
y: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Get PyTorch reference implementation for the given operation."""
if op_id == 0:
assert y is not None, "Division requires second operand"
return x / y
elif op_id == 1:
return torch.reciprocal(x)
elif op_id == 2:
return torch.exp(x)
elif op_id == 3:
return torch.log(x)
elif op_id == 4:
return torch.sin(x)
elif op_id == 5:
return torch.cos(x)
elif op_id == 6:
return torch.sqrt(x)
elif op_id == 7:
return torch.tanh(x)
elif op_id == 8:
return torch.rsqrt(x)
elif op_id == 9:
return 1 / torch.sqrt(x)
else:
raise ValueError(f"Unknown op_id: {op_id}")
def summarize_error(tag: str, output: Optional[torch.Tensor], reference: torch.Tensor) -> None:
"""Summarize and print error statistics for an implementation."""
if output is None:
print(f"{tag:<32} FAILED")
return
# Convert results to double precision for error calculation
output_double = output.double()
reference_double = reference.double() if reference.dtype != torch.float64 else reference
abs_err = (output_double - reference_double).abs()
rel_err = abs_err / (reference_double.abs().clamp_min(1e-30))
print(f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, "
f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}")
# Precision comparison function
def compare(op_id: int, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> None:
name = OP_NAMES[op_id]
print(f"\n=== {name} ===")
# Create double precision version of input data as reference standard
x_double = x.double()
y_double = y.double() if y is not None else None
# Double CUDA Precise as golden standard
ref_double = torch.empty_like(x_double)
mod.launch_double_precise_operator(x_double, y_double, ref_double, op_id)
# CUDA Precise (FP32)
ref_float = torch.empty_like(x)
mod.launch_precise_operator(x, y, ref_float, op_id)
# CUDA Fast
result_fast = torch.empty_like(ref_float)
mod.launch_fast_operator(x, y, result_fast, op_id)
# PyTorch reference
torch_ref = get_pytorch_reference(x, op_id, y)
# Test implementations with error handling
implementations = [
("Standard Triton", lambda: triton_op(x, op_id, y)),
("LibDevice Triton", lambda: triton_libdevice_op(x, op_id, y)),
("TileLang Standard", lambda: tilelang_op(x, op_id, y, use_fastmath=False)),
("TileLang Fastmath", lambda: tilelang_op(x, op_id, y, use_fastmath=True)),
]
results = {}
for name, impl_func in implementations:
try:
results[name] = impl_func()
except Exception as e:
print(f"{name} failed: {e}")
results[name] = None
# Print comparison header
print(
f"{'Implementation':<32} {'Max Abs Error':<19} {'Mean Abs Error':<20} {'Max Rel Error':<19} {'Mean Rel Error'}"
)
print("-" * 90)
# Compare all implementations against double precision reference
comparisons = [
("FP32 Precise vs Double", ref_float),
("Triton LibDevice vs Double", results.get("LibDevice Triton")),
("TileLang vs Double", results.get("TileLang Standard")),
("PyTorch vs Double", torch_ref),
("Triton vs Double", results.get("Standard Triton")),
("TileLang Fastmath vs Double", results.get("TileLang Fastmath")),
("CUDA Fast vs Double", result_fast),
]
for tag, output in comparisons:
summarize_error(tag, output, ref_double)
def generate_test_data(op_id: int, n: int, device: torch.device, low: float,
high: float) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Generate appropriate test data for each operation."""
if op_id == 0: # Division
x = torch.empty(n, device=device).uniform_(low, high)
y = torch.empty(n, device=device).uniform_(1e-3, high) # Avoid division by zero
return x, y
elif op_id in (3, 6): # log and sqrt need positive inputs
x = torch.empty(n, device=device).uniform_(1e-3, high)
return x, None
elif op_id in (8, 9): # rsqrt and inv_sqrt need positive inputs (use consistent data)
x = torch.empty(n, device=device).uniform_(1e-3, high)
return x, None
elif op_id == 1: # reciprocal - avoid values close to zero
x = torch.empty(n, device=device).uniform_(1e-3, high)
return x, None
else: # General case
x = torch.empty(n, device=device).uniform_(low, high)
return x, None
def main() -> None:
"""Main execution function."""
print(
"Precision comparison between CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang"
)
print("=" * 90)
for op_id in range(len(OP_NAMES)):
try:
x, y = generate_test_data(op_id, n, device, low, high)
compare(op_id, x, y)
except Exception as e:
print(f"Error in {OP_NAMES[op_id]}: {e}")
continue
if __name__ == "__main__":
main()
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
enum OperatorType {
OP_DIV,
OP_RECIPROCAL,
OP_EXP,
OP_LOG,
OP_SIN,
OP_COS,
OP_SQRT,
OP_TANH,
OP_RSQRT,
OP_INV_SQRT
};
// ================= 精确版本 device 运算符 =================
__device__ __forceinline__ float precise_div(float a, float b) {
return a / b;
}
__device__ __forceinline__ float precise_reciprocal(float x) {
return 1.0f / x;
}
__device__ __forceinline__ float precise_exp(float x) {
return expf(x);
}
__device__ __forceinline__ float precise_log(float x) {
return logf(x);
}
__device__ __forceinline__ float precise_sin(float x) {
return sinf(x);
}
__device__ __forceinline__ float precise_cos(float x) {
return cosf(x);
}
__device__ __forceinline__ float precise_sqrt(float x) {
return sqrtf(x);
}
__device__ __forceinline__ float precise_tanh(float x) {
return tanhf(x);
}
__device__ __forceinline__ float precise_rsqrt(float x) {
return rsqrtf(x);
}
__device__ __forceinline__ float precise_inv_sqrt(float x) {
return 1.0f / sqrtf(x);
}
// ================= double 精确版本 device 运算符 =================
__device__ __forceinline__ double double_precise_div(double a, double b) {
return a / b;
}
__device__ __forceinline__ double double_precise_reciprocal(double x) {
return 1.0 / x;
}
__device__ __forceinline__ double double_precise_exp(double x) {
return exp(x);
}
__device__ __forceinline__ double double_precise_log(double x) {
return log(x);
}
__device__ __forceinline__ double double_precise_sin(double x) {
return sin(x);
}
__device__ __forceinline__ double double_precise_cos(double x) {
return cos(x);
}
__device__ __forceinline__ double double_precise_sqrt(double x) {
return sqrt(x);
}
__device__ __forceinline__ double double_precise_tanh(double x) {
return tanh(x);
}
__device__ __forceinline__ double double_precise_rsqrt(double x) {
return 1.0 / sqrt(x);
}
__device__ __forceinline__ double double_precise_inv_sqrt(double x) {
return 1.0 / sqrt(x);
}
// ================= 快速近似版本 device 运算符 =================
__device__ __forceinline__ float fast_div(float a, float b) {
return __fdividef(a, b);
}
__device__ __forceinline__ float fast_reciprocal(float x) {
float ret;
asm volatile("rcp.approx.f32 %0, %1;" : "=f"(ret) : "f"(x));
return ret;
}
__device__ __forceinline__ float fast_exp(float x) {
return __expf(x);
}
__device__ __forceinline__ float fast_log(float x) {
return __logf(x);
}
__device__ __forceinline__ float fast_sin(float x) {
return __sinf(x);
}
__device__ __forceinline__ float fast_cos(float x) {
return __cosf(x);
}
__device__ __forceinline__ float fast_sqrt(float x) {
float ret;
asm volatile("sqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x));
return ret;
}
__device__ __forceinline__ float fast_tanh(float x) {
return __tanhf(x);
}
__device__ __forceinline__ float fast_rsqrt(float x) {
// return rsqrtf(x);
float ret;
asm volatile("rsqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x));
return ret;
}
__device__ __forceinline__ float fast_inv_sqrt(float x) {
float ret;
asm volatile("sqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x));
return 1.0f / ret;
}
// ================= 精确版本 kernel =================
__global__ void precise_operator_kernel(const float* x, const float* y, float* result, int64_t n, OperatorType op_type) {
int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
float a = x[i];
float b = (y != nullptr) ? y[i] : 0.0f;
float r = 0.0f;
switch (op_type) {
case OP_DIV: r = precise_div(a, b); break;
case OP_RECIPROCAL: r = precise_reciprocal(a); break;
case OP_EXP: r = precise_exp(a); break;
case OP_LOG: r = precise_log(a); break;
case OP_SIN: r = precise_sin(a); break;
case OP_COS: r = precise_cos(a); break;
case OP_SQRT: r = precise_sqrt(a); break;
case OP_TANH: r = precise_tanh(a); break;
case OP_RSQRT: r = precise_rsqrt(a); break;
case OP_INV_SQRT: r = precise_inv_sqrt(a); break;
}
result[i] = r;
}
}
// ================= double 精确版本 kernel =================
__global__ void double_precise_operator_kernel(const double* x, const double* y, double* result, int64_t n, OperatorType op_type) {
int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
double a = x[i];
double b = (y != nullptr) ? y[i] : 0.0;
double r = 0.0;
switch (op_type) {
case OP_DIV: r = double_precise_div(a, b); break;
case OP_RECIPROCAL: r = double_precise_reciprocal(a); break;
case OP_EXP: r = double_precise_exp(a); break;
case OP_LOG: r = double_precise_log(a); break;
case OP_SIN: r = double_precise_sin(a); break;
case OP_COS: r = double_precise_cos(a); break;
case OP_SQRT: r = double_precise_sqrt(a); break;
case OP_TANH: r = double_precise_tanh(a); break;
case OP_RSQRT: r = double_precise_rsqrt(a); break;
case OP_INV_SQRT: r = double_precise_inv_sqrt(a); break;
}
result[i] = r;
}
}
// ================= 快速版本 kernel =================
__global__ void fast_operator_kernel(const float* x, const float* y, float* result, int64_t n, OperatorType op_type) {
int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
float a = x[i];
float b = (y != nullptr) ? y[i] : 0.0f;
float r = 0.0f;
switch (op_type) {
case OP_DIV: r = fast_div(a, b); break;
case OP_RECIPROCAL: r = fast_reciprocal(a); break;
case OP_EXP: r = fast_exp(a); break;
case OP_LOG: r = fast_log(a); break;
case OP_SIN: r = fast_sin(a); break;
case OP_COS: r = fast_cos(a); break;
case OP_SQRT: r = fast_sqrt(a); break;
case OP_TANH: r = fast_tanh(a); break;
case OP_RSQRT: r = fast_rsqrt(a); break;
case OP_INV_SQRT: r = fast_inv_sqrt(a); break;
}
result[i] = r;
}
}
// 精确版本
void launch_precise_operator(const at::Tensor& x, const c10::optional<at::Tensor>& y, at::Tensor& result, int op_type) {
int64_t n = x.numel();
int threads = 256;
int blocks = (n + threads - 1) / threads;
const float* y_ptr = nullptr;
if (y.has_value()) {
y_ptr = y.value().data_ptr<float>();
}
precise_operator_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
x.data_ptr<float>(), y_ptr, result.data_ptr<float>(), n, static_cast<OperatorType>(op_type)
);
}
// double 精确版本
void launch_double_precise_operator(const at::Tensor& x, const c10::optional<at::Tensor>& y, at::Tensor& result, int op_type) {
int64_t n = x.numel();
int threads = 256;
int blocks = (n + threads - 1) / threads;
const double* y_ptr = nullptr;
if (y.has_value()) {
y_ptr = y.value().data_ptr<double>();
}
double_precise_operator_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
x.data_ptr<double>(), y_ptr, result.data_ptr<double>(), n, static_cast<OperatorType>(op_type)
);
}
// 快速版本
void launch_fast_operator(const at::Tensor& x, const c10::optional<at::Tensor>& y, at::Tensor& result, int op_type) {
int64_t n = x.numel();
int threads = 256;
int blocks = (n + threads - 1) / threads;
const float* y_ptr = nullptr;
if (y.has_value()) {
y_ptr = y.value().data_ptr<float>();
}
fast_operator_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
x.data_ptr<float>(), y_ptr, result.data_ptr<float>(), n, static_cast<OperatorType>(op_type)
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("launch_precise_operator", &launch_precise_operator, "CUDA Precise Operator",
py::arg("x"), py::arg("y") = c10::nullopt, py::arg("result"), py::arg("op_type"));
m.def("launch_double_precise_operator", &launch_double_precise_operator, "CUDA Double Precise Operator",
py::arg("x"), py::arg("y") = c10::nullopt, py::arg("result"), py::arg("op_type"));
m.def("launch_fast_operator", &launch_fast_operator, "CUDA Fast Operator",
py::arg("x"), py::arg("y") = c10::nullopt, py::arg("result"), py::arg("op_type"));
}
\ No newline at end of file
...@@ -41,6 +41,31 @@ DataType cuTensorMapType() { return DataType::UInt(8, 128); } ...@@ -41,6 +41,31 @@ DataType cuTensorMapType() { return DataType::UInt(8, 128); }
TVM_REGISTER_OP("tl." #OpName) \ TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) .set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)
// fast math related op
TIR_DEFINE_TL_BUILTIN(__exp).set_num_inputs(1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(__exp10).set_num_inputs(1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(__log).set_num_inputs(1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(__log2).set_num_inputs(1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(__log10).set_num_inputs(1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(__tan).set_num_inputs(1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(__cos).set_num_inputs(1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(__sin).set_num_inputs(1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier) TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
.set_num_inputs(-1) .set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
......
...@@ -89,6 +89,16 @@ static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment"; ...@@ -89,6 +89,16 @@ static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";
*/ */
DataType cuTensorMapType(); DataType cuTensorMapType();
// fast math related op
TVM_DLL const Op &__exp();
TVM_DLL const Op &__exp10();
TVM_DLL const Op &__log();
TVM_DLL const Op &__log2();
TVM_DLL const Op &__log10();
TVM_DLL const Op &__tan();
TVM_DLL const Op &__cos();
TVM_DLL const Op &__sin();
/*! /*!
* \brief tvm intrinsics for TMADescriptor creation for tiled load * \brief tvm intrinsics for TMADescriptor creation for tiled load
* *
......
...@@ -21,6 +21,79 @@ namespace tvm { ...@@ -21,6 +21,79 @@ namespace tvm {
namespace codegen { namespace codegen {
using namespace tvm::tl::codegen; using namespace tvm::tl::codegen;
struct CUDAMath {
std::string operator()(DataType t, std::string name) const {
if (t.is_float()) {
switch (t.bits()) {
case 64:
return name;
case 32:
return name + 'f';
case 16: {
if (name == "fabs") {
return "__habs";
} else if (name == "round") {
return "hrint";
} else {
return "h" + name;
}
}
default:
return "";
}
} else if (t.is_bfloat16()) {
if (name == "fabs") {
return "__habs";
} else if (name == "round") {
return "hrint";
} else {
return "h" + name;
}
} else if (t.is_int() || t.is_uint()) {
switch (t.bits()) {
case 32:
return "__" + name;
case 64:
return "__" + name + "ll";
default:
return "";
}
}
return "";
}
};
struct CUDAFastMath : public CUDAMath {
std::string operator()(DataType t, std::string name) const {
if (t.is_float() && t.bits() == 32) {
return "__" + name + 'f';
} else {
return CUDAMath::operator()(t, name);
}
return "";
}
};
struct CUDAFastMathTan : public CUDAMath {
std::string operator()(DataType t, std::string name) const {
if (t.is_float()) {
switch (t.bits()) {
case 64:
return name;
// `__tanf` seems to produce some values too deviant from numpy tan
// version. So, let's use just `tanf` instead.
case 32:
return name + 'f';
case 16:
return 'h' + name;
default:
return "";
}
}
return "";
}
};
static std::string GetFP8Type(DataType type) { static std::string GetFP8Type(DataType type) {
std::stringstream stream; std::stringstream stream;
int32_t lanes = type.lanes(); int32_t lanes = type.lanes();
...@@ -1628,6 +1701,38 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -1628,6 +1701,38 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
op->args, true, os); op->args, true, os);
} else if (op->op.same_as(tl::tl_shuffle_elect())) { } else if (op->op.same_as(tl::tl_shuffle_elect())) {
os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()"; os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
} else if (op->op.same_as(tl::__exp())) {
CUDAFastMath math_func;
std::string func_name = math_func(op->dtype, "exp");
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__exp10())) {
CUDAFastMath math_func;
std::string func_name = math_func(op->dtype, "exp10");
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__log())) {
CUDAFastMath math_func;
std::string func_name = math_func(op->dtype, "log");
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__log2())) {
CUDAFastMath math_func;
std::string func_name = math_func(op->dtype, "log2");
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__log10())) {
CUDAFastMath math_func;
std::string func_name = math_func(op->dtype, "log10");
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__tan())) {
CUDAFastMath math_func;
std::string func_name = math_func(op->dtype, "tan");
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__cos())) {
CUDAFastMath math_func;
std::string func_name = math_func(op->dtype, "cos");
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__sin())) {
CUDAFastMath math_func;
std::string func_name = math_func(op->dtype, "sin");
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else { } else {
CodeGenC::VisitExpr_(op, os); CodeGenC::VisitExpr_(op, os);
} }
......
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import re
def get_mathop_lines(source, mathop_name):
"""Extract lines containing the mathop from CUDA source for debugging"""
lines = source.split('\n')
relevant_lines = []
for i, line in enumerate(lines):
if mathop_name in line and ('(' in line):
# Include some context
start = max(0, i - 1)
end = min(len(lines), i + 2)
relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)])
relevant_lines.append("---")
return '\n'.join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output
def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
"""Check source for fastmath/non-fastmath versions"""
fastmath_pattern = rf"__({mathop_name}f?)\b"
non_fastmath_pattern = rf"(?<!__)({mathop_name}f?)\b"
fastmath_matches = re.findall(fastmath_pattern, source)
non_fastmath_matches = re.findall(non_fastmath_pattern, source)
print(
f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls"
)
if len(fastmath_matches) > 0:
print(f"Fastmath calls found: {fastmath_matches}")
if len(non_fastmath_matches) > 0:
print(f"Non-fastmath calls found: {non_fastmath_matches}")
print(f"Source preview for {mathop_name}:")
print(get_mathop_lines(source, mathop_name))
if expect_fastmath:
assert len(fastmath_matches) > 0, "Expected fastmath calls but found none"
print(f"✓ {mathop_name} correctly uses fastmath versions")
else:
assert len(fastmath_matches) == 0, f"Found unexpected fastmath calls: {fastmath_matches}"
assert len(non_fastmath_matches) > 0, f"No {mathop_name} calls found"
print(f"✓ {mathop_name} correctly uses non-fastmath versions")
def check_non_fastmath_usage(source, mathop_name):
"""Check that source uses non-fastmath versions (no __ prefix)"""
check_fastmath_usage(source, mathop_name, expect_fastmath=False)
def run_single_arg_mathop_test(mathop_name,
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
"""
Test single-argument mathops.
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
"""
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
bx * block_N + j])
# Test with FAST_MATH disabled
kernel_no_fastmath = tilelang.compile(
main,
out_idx=[1],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
source_no_fastmath = kernel_no_fastmath.get_kernel_source()
print(f"\n=== Testing {mathop_name} ===")
print("FAST_MATH=False:")
# Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
print(f"✓ {mathop_name} compilation and execution test passed")
def run_two_arg_mathop_test(mathop_name,
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
"""
Test two-argument mathops to ensure they generate non-fastmath CUDA code.
"""
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i,
bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
B[by * block_M + i, bx * block_N + j])
# Test with FAST_MATH disabled
kernel_no_fastmath = tilelang.compile(
main,
out_idx=[2],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
# Test with FAST_MATH enabled
kernel_fastmath = tilelang.compile(
main,
out_idx=[2],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
source_no_fastmath = kernel_no_fastmath.get_kernel_source()
source_fastmath = kernel_fastmath.get_kernel_source()
print(f"\n=== Testing {mathop_name} (two args) ===")
print("FAST_MATH=False:")
check_non_fastmath_usage(source_no_fastmath, mathop_name)
print("FAST_MATH=True:")
check_non_fastmath_usage(source_fastmath, mathop_name)
# Test numerical correctness
torch_dtype = getattr(torch, dtype)
a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
b = torch.randn(M, N, device="cuda", dtype=torch_dtype)
# Ensure positive values for functions that need them
if mathop_name == "pow":
a = torch.abs(a) + 0.1
b = torch.clamp(b, -3, 3) # Limit exponent range
elif mathop_name == "fmod":
b = torch.abs(b) + 0.1 # Avoid division by zero
c_no_fastmath = kernel_no_fastmath(a, b)
c_fastmath = kernel_fastmath(a, b)
# Both should produce similar results
torch.testing.assert_close(c_no_fastmath, c_fastmath, rtol=1e-3, atol=1e-3)
print(f"✓ {mathop_name} numerical test passed")
def run_abs_test():
"""Test that abs correctly maps to fabs (not __fabsf) in generated CUDA code"""
M, N = 128, 128
block_M, block_N = 32, 32
@T.prim_func
def main(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = T.abs(A[by * block_M + i, bx * block_N + j])
kernel = tilelang.compile(
main,
out_idx=[1],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
source = kernel.get_kernel_source()
print("\n=== Testing abs (maps to fabs) ===")
check_non_fastmath_usage(source, "fabs")
# Test numerical correctness
a = torch.randn(M, N, device="cuda", dtype=torch.float32)
b = kernel(a)
expected = torch.abs(a)
torch.testing.assert_close(b, expected, rtol=1e-5, atol=1e-5)
print("✓ abs numerical test passed")
def run_fastmath_mathop_test(mathop_name,
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
"""
Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
"""
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
bx * block_N + j])
# Test with FAST_MATH enabled
kernel_fastmath = tilelang.compile(
main,
out_idx=[1],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
source_fastmath = kernel_fastmath.get_kernel_source()
print(f"\n=== Testing {mathop_name} (fastmath version) ===")
print("FAST_MATH=True:")
# Strip the __ prefix for checking in the CUDA source
cuda_mathop_name = mathop_name.lstrip('_')
check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True)
# Test numerical correctness
torch_dtype = getattr(torch, dtype)
a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
# Ensure positive values for functions that need them
if cuda_mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]:
a = torch.abs(a) + 0.1
b_fastmath = kernel_fastmath(a)
# Compare with reference implementation
if cuda_mathop_name == "exp":
expected = torch.exp(a)
elif cuda_mathop_name == "log":
expected = torch.log(a)
else:
expected = b_fastmath # Just check compilation works
torch.testing.assert_close(b_fastmath, expected, rtol=1e-3, atol=1e-3)
print(f"✓ {mathop_name} numerical test passed")
@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath():
"""Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
# Based on test results, our tl.* intrinsics actually generate
# no fastmath versions
# This appears to be the intended behavior
single_arg_mathops = [
("exp", T.exp),
("exp2", T.exp2),
("exp10", T.exp10),
("log", T.log),
("log2", T.log2),
("log10", T.log10),
("sin", T.sin),
("cos", T.cos),
("tan", T.tan),
("sinh", T.sinh),
("cosh", T.cosh),
("tanh", T.tanh),
("atan", T.atan),
("sqrt", T.sqrt),
("rsqrt", T.rsqrt),
("erf", T.erf),
("floor", T.floor),
("ceil", T.ceil),
("trunc", T.trunc),
("round", T.round),
("nearbyint", T.nearbyint),
]
for name, func in single_arg_mathops:
run_single_arg_mathop_test(name, func, dtype="float32")
print(f"✓ {name} test passed")
@tilelang.testing.requires_cuda
def test_two_arg_mathops_fastmath():
"""Test all two-argument mathops"""
# Two argument mathops
two_arg_mathops = [
("pow", T.pow),
("fmod", T.fmod),
]
for name, func in two_arg_mathops:
run_two_arg_mathop_test(name, func, dtype="float32")
@tilelang.testing.requires_cuda
def test_abs_maps_to_fabs():
"""Test that abs correctly maps to fabs"""
run_abs_test()
@tilelang.testing.requires_cuda
def test_fastmath_versions():
"""Test that __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin generate fastmath CUDA code"""
# Test fastmath versions
fastmath_mathops = [
("__exp", T.__exp),
("__exp10", T.__exp10),
("__log", T.__log),
("__log2", T.__log2),
("__log10", T.__log10),
("__tan", T.__tan),
("__cos", T.__cos),
("__sin", T.__sin),
]
for name, func in fastmath_mathops:
run_fastmath_mathop_test(name, func, dtype="float32")
print(f"✓ {name} test passed")
if __name__ == "__main__":
tilelang.disable_cache()
tilelang.testing.main()
...@@ -26,6 +26,7 @@ from .parallel import Parallel # noqa: F401 ...@@ -26,6 +26,7 @@ from .parallel import Parallel # noqa: F401
from .pipeline import Pipelined # noqa: F401 from .pipeline import Pipelined # noqa: F401
from .persistent import Persistent # noqa: F401 from .persistent import Persistent # noqa: F401
from .frame import has_let_value, get_let_value # noqa: F401 from .frame import has_let_value, get_let_value # noqa: F401
from .fastmath import * # noqa: F401
from .kernel import ( from .kernel import (
Kernel, # noqa: F401 Kernel, # noqa: F401
KernelLaunchFrame, # noqa: F401 KernelLaunchFrame, # noqa: F401
......
from tvm import tir
def __log(x):
"""Calculate log(x) with fast math
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
x = tir.convert(x)
return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__log"), x)
def __log2(x):
"""Calculate log2(x) with fast math
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
x = tir.convert(x)
return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__log2"), x)
def __log10(x):
"""Calculate log10(x) with fast math
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
x = tir.convert(x)
return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__log10"), x)
def __tan(x):
"""Calculate tan(x) with fast math
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
x = tir.convert(x)
return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__tan"), x)
def __cos(x):
"""Calculate cos(x) with fast math
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
x = tir.convert(x)
return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__cos"), x)
def __sin(x):
"""Calculate sin(x) with fast math
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
x = tir.convert(x)
return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__sin"), x)
def __exp10(x):
"""Calculate 10**x with fast math
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
x = tir.convert(x)
return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__exp10"), x)
def __exp(x):
"""Calculate 2**x with fast math
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
x = tir.convert(x)
return tir.call_intrin(x.dtype, tir.op.Op.get("tl.__exp"), x)
__all__ = [
"__log", # noqa: F401
"__log2", # noqa: F401
"__log10", # noqa: F401
"__tan", # noqa: F401
"__cos", # noqa: F401
"__sin", # noqa: F401
"__exp10", # noqa: F401
"__exp", # noqa: F401
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment