Commit d0434c3e authored by Lei Wang, committed by GitHub

[Enhancement] Enable runtime tensor data type validation (#146)

* Fix debug print buffer template for unsigned char type

- Update debug_print_buffer_value template specialization for unsigned char
- Modify test_tilelang_debug_print.py to include additional dtype tests
- Add test case for uint8 dtype in debug print buffer function

* Refactor debug print buffer template formatting for unsigned char

- Improve code formatting for debug_print_buffer_value template specialization
- Adjust line breaks and indentation for better readability
- Maintain consistent code style with other template specializations

* Extract map_torch_type utility function to tilelang.utils.tensor

- Move map_torch_type function from multiple test files to a centralized location
- Import map_torch_type from tilelang.utils.tensor in kernel test files
- Improve code reusability by creating a shared utility function for type mapping
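
A minimal usage sketch of the now-shared helper (behavior taken from the function body moved in this commit; the assert-style calls are illustrative only):

```python
import torch
from tilelang.utils.tensor import map_torch_type

# String dtype names used across the kernel tests map onto torch dtypes;
# the two fp8 spellings are special-cased, everything else falls back to
# getattr(torch, name).
assert map_torch_type("float16") is torch.float16
assert map_torch_type("e4m3_float8") is torch.float8_e4m3fn
assert map_torch_type("e5m2_float8") is torch.float8_e5m2
```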

* Add buffer dtype mapping for Cython kernel adapter

- Introduce buffer_dtype_map in CythonKernelAdapter to track buffer variable dtypes
- Add _process_buffer_dtype method to extract dtype information from TIR function
- Update CythonKernelWrapper to support setting and validating buffer dtypes
- Enhance type checking during kernel execution with dtype verification
- Improve logging message for Cython JIT adapter compilation
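
A minimal pure-Python analogue of the check the wrapper now performs at call time (the map layout matches `_process_buffer_dtype`; the buffer names, shapes, and tensors below are made up for illustration):

```python
import torch

# buffer name -> (position in the params list, expected torch dtype),
# as produced by CythonKernelAdapter._process_buffer_dtype.
buffer_dtype_map = {"A": (0, torch.float16), "B": (1, torch.float16)}

tensor_list = [
    torch.empty(16, 16, dtype=torch.float16),
    torch.empty(16, 16, dtype=torch.float32),  # wrong dtype on purpose
]

# Mirrors the new loop in CythonKernelWrapper.forward: fail fast with a clear
# message instead of handing a mistyped pointer to the compiled kernel.
for param, (buffer_idx, torch_dtype) in buffer_dtype_map.items():
    if tensor_list[buffer_idx].dtype != torch_dtype:
        raise ValueError(
            f"Buffer dtype mismatch for parameter {param}: "
            f"expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")
```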

* Add static shape mapping for Cython kernel adapter

- Introduce static_shape_map in CythonKernelAdapter to track buffer variable static shapes
- Add _process_static_shape method to extract static shape information from TIR function
- Update CythonKernelWrapper to support setting and validating static shapes
- Enhance type checking during kernel execution with static shape verification
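
The corresponding static-shape check as a standalone sketch with made-up shapes; the map layout `{name: (param_index, [(dim_index, extent), ...])}` mirrors `_process_static_shape`, and dynamic dimensions are simply absent from the list:

```python
import torch

# "A" is parameter 0 with compile-time shape (16, 16).
static_shape_map = {"A": (0, [(0, 16), (1, 16)])}

tensor_list = [torch.empty(16, 32)]  # second dim is wrong on purpose

# Mirrors the new loop in CythonKernelWrapper.forward: only extents that were
# IntImm in the TIR buffer declaration are verified against the runtime tensor.
for param, (buffer_idx, shape_list) in static_shape_map.items():
    for shape_idx, shape in shape_list:
        if tensor_list[buffer_idx].shape[shape_idx] != shape:
            raise ValueError(
                f"Static shape mismatch for parameter {param}: "
                f"expected {shape}, got {tensor_list[buffer_idx].shape}")
```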

* Add Multi-Head Attention (MHA) Backward Pass Test for TileLang Kernel

- Implement comprehensive test for Multi-Head Attention backward pass
- Support both causal and non-causal attention scenarios
- Add reference implementation for comparing kernel outputs
- Test different batch sizes, head counts, sequence lengths, and head dimensions
- Verify forward and backward pass correctness using torch.testing.assert_close
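
For orientation, a generic PyTorch attention reference of the kind the test compares against (a standalone sketch, not the exact reference function in the test file; a [batch, seq_len, heads, dim] layout is assumed):

```python
import torch

def ref_attention(Q, K, V, causal):
    # Q, K, V: [batch, seq_len, heads, dim]; plain softmax attention in fp32.
    dim = Q.size(-1)
    scores = torch.einsum("bqhd,bkhd->bhqk", Q.float(), K.float()) / dim**0.5
    if causal:
        q_len, k_len = scores.shape[-2:]
        mask = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool, device=scores.device))
        scores = scores.masked_fill(~mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", attn, V.float()).to(Q.dtype)

# Comparison pattern used by the test: run the kernel and the reference on the
# same inputs, backpropagate matching upstream gradients, then check O, dQ, dK,
# dV with torch.testing.assert_close(..., rtol=1e-2, atol=1e-2).
```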

* Set random seed for MHA backward pass test

- Add random seed initialization for consistent test reproducibility
- Use tilelang.testing.set_random_seed(42) to ensure deterministic test results
parent bb60f6ce
@@ -12,6 +12,7 @@ from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mma_macro_generator import (
     TensorCoreIntrinEmitter,)
 from tilelang.transform import simplify_prim_func
+from tilelang.utils.tensor import map_torch_type

 tilelang.testing.set_random_seed(0)
@@ -186,16 +187,6 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     # src_code is the generated cuda source
     assert src_code is not None

-    def map_torch_type(intype):
-        typemap = {
-            'e4m3_float8': torch.float8_e4m3fn,
-            'e5m2_float8': torch.float8_e5m2,
-        }
-        if intype in typemap:
-            return typemap[intype]
-        else:
-            return getattr(torch, intype)
-
     in_dtype = map_torch_type(in_dtype)
     out_dtype = map_torch_type(out_dtype)
     accum_dtype = map_torch_type(accum_dtype)
...
@@ -12,6 +12,7 @@ from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mma_macro_generator import (
     TensorCoreIntrinEmitter,)
 from tilelang.transform import simplify_prim_func
+from tilelang.utils.tensor import map_torch_type

 tilelang.testing.set_random_seed(0)
@@ -186,16 +187,6 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     # src_code is the generated cuda source
     assert src_code is not None

-    def map_torch_type(intype):
-        typemap = {
-            'e4m3_float8': torch.float8_e4m3fn,
-            'e5m2_float8': torch.float8_e5m2,
-        }
-        if intype in typemap:
-            return typemap[intype]
-        else:
-            return getattr(torch, intype)
-
     in_dtype = map_torch_type(in_dtype)
     out_dtype = map_torch_type(out_dtype)
     accum_dtype = map_torch_type(accum_dtype)
...
@@ -8,6 +8,7 @@ from tvm import DataType
 import tilelang.language as T
 from tilelang import JITKernel
 from tilelang.transform.simplify import apply_simplify
+from tilelang.utils.tensor import map_torch_type
 from typing import Optional

 tilelang.testing.set_random_seed(0)
@@ -131,16 +132,6 @@ def evaluate_gemv_simt(
     kernel = JITKernel(program, target="cuda")

-    def map_torch_type(intype):
-        typemap = {
-            'e4m3_float8': torch.float8_e4m3fn,
-            'e5m2_float8': torch.float8_e5m2,
-        }
-        if intype in typemap:
-            return typemap[intype]
-        else:
-            return getattr(torch, intype)
-
     in_dtype = map_torch_type(in_dtype)
     out_dtype = map_torch_type(out_dtype)
     accum_dtype = map_torch_type(accum_dtype)
...
@@ -12,6 +12,7 @@ from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mma_macro_generator import (
     TensorCoreIntrinEmitter,)
 from tilelang.transform import simplify_prim_func
+from tilelang.utils.tensor import map_torch_type

 tilelang.testing.set_random_seed(0)
@@ -186,16 +187,6 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     # src_code is the generated cuda source
     assert src_code is not None

-    def map_torch_type(intype):
-        typemap = {
-            'e4m3_float8': torch.float8_e4m3fn,
-            'e5m2_float8': torch.float8_e5m2,
-        }
-        if intype in typemap:
-            return typemap[intype]
-        else:
-            return getattr(torch, intype)
-
     in_dtype = map_torch_type(in_dtype)
     out_dtype = map_torch_type(out_dtype)
     accum_dtype = map_torch_type(accum_dtype)
...
@@ -8,6 +8,7 @@ from tvm import DataType
 import tilelang.language as T
 from tilelang import JITKernel
 from tilelang.transform.simplify import apply_simplify
+from tilelang.utils.tensor import map_torch_type
 from typing import Optional

 tilelang.testing.set_random_seed(0)
@@ -131,16 +132,6 @@ def evaluate_gemv_simt(
     kernel = JITKernel(program, target="cuda")

-    def map_torch_type(intype):
-        typemap = {
-            'e4m3_float8': torch.float8_e4m3fn,
-            'e5m2_float8': torch.float8_e5m2,
-        }
-        if intype in typemap:
-            return typemap[intype]
-        else:
-            return getattr(torch, intype)
-
     in_dtype = map_torch_type(in_dtype)
     out_dtype = map_torch_type(out_dtype)
     accum_dtype = map_torch_type(accum_dtype)
...
@@ -10,6 +10,8 @@ import tilelang.language as T
 import tilelang.testing

+tilelang.testing.set_random_seed(42)
+
 def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
@@ -302,10 +304,10 @@ def assert_mha_equal(batch, h, n_ctx, d_head, causal):
     dK_ref, K.grad = K.grad.clone(), None
     dV_ref, V.grad = V.grad.clone(), None

-    assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
-    assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
-    assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
-    assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)

 def test_mha_bwd():
...
@@ -13,8 +13,9 @@ from tilelang.jit.adapter.wrapper import TLWrapper
 from tilelang.jit.adapter.libgen import LibraryGenerator
 from tilelang.utils.target import determine_target
 from tilelang.utils.language import retrieve_func_from_module
+from tilelang.utils.tensor import map_torch_type
 from tilelang.contrib.cc import get_cplus_compiler
+import torch
 import sys
 import sysconfig
 import hashlib
@@ -89,7 +90,7 @@ with open(cython_wrapper_path, "r") as f:
                 logger.debug("Cython jit adapter is up to date, no need to compile...")
                 need_compile = False
             else:
-                logger.info("Cython jit adapter is out of date, need to compile...")
+                logger.info("Cython jit adapter is out of date, need to recompile...")
         else:
             logger.info("No cached version found for cython jit adapter, need to compile...")
@@ -135,6 +136,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
     wrapped_source: Optional[str] = None  # Generated C++ wrapper code
     # Maps symbolic variables to their corresponding buffer and shape indices
     dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None
+    # Maps buffer variables to their corresponding dtypes
+    buffer_dtype_map: Optional[Dict[tir.Var, Tuple[int, torch.dtype]]] = None
+    # Maps buffer variables to their corresponding static shapes
+    # {
+    #     "A": [(0, 16), (1, 16)] -> represents A.shape = (16, 16)
+    # }
+    static_shape_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None

     def __init__(self,
                  rt_mod,
@@ -163,6 +171,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
         self.ir_module = func_or_mod

         self.dynamic_symbolic_map = self._process_dynamic_symbolic()
+        self.buffer_dtype_map = self._process_buffer_dtype()
+        self.static_shape_map = self._process_static_shape()

         self.target = Target.canon_target(determine_target(target))
         self.verbose = verbose
@@ -182,12 +192,14 @@ class CythonKernelAdapter(BaseKernelAdapter):
             raise Exception(
                 f"Failed to initialize the compiled library for {self.target}: {e}") from e

-        self.cython_wrapper = CythonKernelWrapper(self.dynamic_symbolic_map, self.result_idx,
-                                                  self.params, self.lib)
+        self.cython_wrapper = CythonKernelWrapper(self.result_idx, self.params, self.lib)
+        self.cython_wrapper.set_dynamic_symbolic_map(self.dynamic_symbolic_map)
+        self.cython_wrapper.set_buffer_dtype_map(self.buffer_dtype_map)
+        self.cython_wrapper.set_static_shape_map(self.static_shape_map)
         self._post_init()

-    def _process_dynamic_symbolic(self):
+    def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]:
         """Extract information about dynamic shapes from the TIR function.

         Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
@@ -205,6 +217,43 @@ class CythonKernelAdapter(BaseKernelAdapter):
                         dynamic_symbolic_map[shape] = (i, j)
         return dynamic_symbolic_map

+    def _process_buffer_dtype(self) -> Dict[tir.Var, Tuple[int, torch.dtype]]:
+        """Extract information about buffer dtypes from the TIR function.
+
+        Maps buffer variables to their corresponding dtypes.
+        """
+        func = self.prim_func
+        params = func.params
+        buffer_map = func.buffer_map
+        buffer_dtype_map = {}
+        for i, param in enumerate(params):
+            if param in buffer_map:
+                buffer = buffer_map[param]
+                name, dtype = buffer.name, buffer.dtype
+                buffer_dtype_map[name] = (i, map_torch_type(dtype))
+        return buffer_dtype_map
+
+    def _process_static_shape(self) -> Dict[tir.Var, List[Tuple[int, int]]]:
+        """Extract information about static shapes from the TIR function.
+
+        Maps buffer variables to their corresponding static shapes.
+        """
+        func = self.prim_func
+        params = func.params
+        buffer_map = func.buffer_map
+        static_shape_map = {}
+        for i, param in enumerate(params):
+            if param in buffer_map:
+                buffer = buffer_map[param]
+                name = buffer.name
+                shape = buffer.shape
+                static_shape = []
+                for j, s in enumerate(shape):
+                    if isinstance(s, tir.IntImm):
+                        static_shape.append((j, s.value))
+                static_shape_map[name] = (i, static_shape)
+        return static_shape_map
+
     def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
         """Low-level function to call the compiled CUDA kernel.
...
@@ -12,17 +12,30 @@ cdef class CythonKernelWrapper:
     # Class attributes to store kernel configuration and library reference
     cdef:
        object dynamic_symbolic_map  # Maps dynamic dimensions to their corresponding tensor indices
+       object buffer_dtype_map  # Maps buffer variables to their corresponding dtypes
+       object static_shape_map  # Maps buffer variables to their corresponding static shapes
        list result_idx  # Indices of output tensors in the params list
        list params  # List of parameter specifications (includes both inputs and outputs)
        object lib  # Reference to the compiled library containing the kernel

-    def __cinit__(self, dynamic_symbolic_map, result_idx, params, lib):
+    def __cinit__(self, result_idx, params, lib):
         # Initialize wrapper with kernel configuration
-        self.dynamic_symbolic_map = dynamic_symbolic_map
         self.result_idx = result_idx
         self.params = params
         self.lib = lib

+    def set_dynamic_symbolic_map(self, dynamic_symbolic_map):
+        self.dynamic_symbolic_map = dynamic_symbolic_map
+        return self
+
+    def set_buffer_dtype_map(self, buffer_dtype_map):
+        self.buffer_dtype_map = buffer_dtype_map
+        return self
+
+    def set_static_shape_map(self, static_shape_map):
+        self.static_shape_map = static_shape_map
+        return self
+
     cpdef forward(self, list inputs, int64_t stream = -1):
         # Validate input dimensions and prepare for kernel execution
         cdef int total_params = len(self.params)
@@ -69,6 +82,17 @@ cdef class CythonKernelWrapper:
             else:
                 raise ValueError(f"Unsupported tensor type: {type(tensor_list[i])}")

+        # Check buffer dtype map
+        for param, (buffer_idx, torch_dtype) in self.buffer_dtype_map.items():
+            if tensor_list[buffer_idx].dtype != torch_dtype:
+                raise ValueError(f"Buffer dtype mismatch for parameter {param}: expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")
+        # Check static shape map
+        for param, (buffer_idx, shape_list) in self.static_shape_map.items():
+            for shape_idx, shape in shape_list:
+                if tensor_list[buffer_idx].shape[shape_idx] != shape:
+                    raise ValueError(f"Static shape mismatch for parameter {param}: expected {shape}, got {tensor_list[buffer_idx].shape}")
+
         # Add dynamic dimension values to kernel arguments
         for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
             call_args.append(tensor_list[buffer_idx].shape[shape_idx])
...
@@ -18,7 +18,7 @@ class TensorSupplyType(Enum):
     Auto = 7


-def map_torch_type(intype):
+def map_torch_type(intype: str) -> torch.dtype:
     typemap = {
         'e4m3_float8': torch.float8_e4m3fn,
         'e5m2_float8': torch.float8_e5m2,
...