Unverified Commit 7248a810 authored by Gabriel Wu, committed by GitHub

feat(cutedsl): add CuTeDSL backend (#1421)



* feat: CuTeDSL backend

* fix: clang-tidy

* fix: clang-format

* fix: ci

* fix: revert example gemm fp8

* fix: remove duplicate code

* fix: switch-case

* fix: fp16 silence

* fix: TVM IR print

* fix: useless tir

* fix: clang-format

* fix: remove tilelang/contrib/cutedsl/.gitignore

* fix: use hexfloat

* fix: gsym guard

* fix: unknown storage sync type

* fix: string literal

* fix: add args guard

* fix: name hint dedup

* fix: better find_kernel_by_pattern

* fix: set libpath for from_database path

* fix: guard buffer.strides

* fix: from guard

* fix: eviction guard

* fix: use thread local tma descs

* fix: ruff

* fix: drop tma_init_cpp

* fix: exc_info

* fix: negative unmatch early return

* fix: rename postproc func and add test

* fix: handle fast math according to pass config

* fix: dyn_sym parse

* fix: wrap_forward

* fix: use tvm_ffi.libinfo instead of cli

* fix: keep signature

* fix: C++ string safety

* fix: mark tma_store_add as unsupported

* fix: tvm version

* resolve ldsm and cpasync issues.

* fix: minor fixes

* fix: parse signature using ast

* fix: guard global_addr

* fix: create tempfile only when necessary

* fix: use logger.exception for exceptions

* fix: guard lib_path and host_func

* fix: remove tma_cpp_init and add timeout for cpp compile

* add timeout for mbarrier_wait.

* fix: _load_kernel_from_disk signature

* resolve codegen issues.

* fix: logger.exception

* add comment for div_by=1

* merge

* fix: reserve cutlass,cute,tl

* fix: guard tma_store

* fix: allow int64 offset in make_tensor_at_offset

* fix: guard barrier

* fix: add comments for div_by=16

* fix: div_by=1 issue

* delete div_by when offset is 0

* use tl.make_tensor when offset is 0

* fix: explicitly check cutedsl target

* fix: use param.torch_dtype()

---------
Co-authored-by: yuxic <yuxic@nvidia.com>
Co-authored-by: Yong <yong@local>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent a6f59f31
"""
Reduce operations for CuTeDSL backend.
Based on tl_templates/cuda/reduce.h
"""
from __future__ import annotations
import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import Int32, Float32
from cutlass.cutlass_dsl import dsl_user_op, T
from cutlass._mlir.dialects import nvvm
from cutlass.cute.arch.nvvm_wrappers import shuffle_sync_op
@dsl_user_op
def min(a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None) -> Float32:
return Float32(
nvvm.fmin(
T.f32(),
Float32(a).ir_value(loc=loc, ip=ip),
Float32(b).ir_value(loc=loc, ip=ip),
c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
loc=loc,
ip=ip,
)
)
@dsl_user_op
def max(a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None) -> Float32:
return Float32(
nvvm.fmax(
T.f32(),
Float32(a).ir_value(loc=loc, ip=ip),
Float32(b).ir_value(loc=loc, ip=ip),
c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
loc=loc,
ip=ip,
)
)
class SumOp:
"""Sum reduction operator"""
@staticmethod
def __call__(x, y):
return x + y
class MaxOp:
"""Max reduction operator"""
@staticmethod
def __call__(x, y):
return max(x, y)
class MinOp:
"""Min reduction operator"""
@staticmethod
def __call__(x, y):
# Use cutlass.min which is JIT-friendly
return min(x, y)
class BitAndOp:
"""Bitwise AND reduction operator"""
@staticmethod
def __call__(x, y):
return x & y
class BitOrOp:
"""Bitwise OR reduction operator"""
@staticmethod
def __call__(x, y):
return x | y
class BitXorOp:
"""Bitwise XOR reduction operator"""
@staticmethod
def __call__(x, y):
return x ^ y
def bar_sync(barrier_id, number_of_threads):
cute.arch.barrier(barrier_id=barrier_id, number_of_threads=number_of_threads)
def bar_sync_ptx(barrier_id, number_of_threads):
from cutlass._mlir.dialects import llvm
llvm.inline_asm(
None,
[Int32(barrier_id).ir_value(), Int32(number_of_threads).ir_value()],
"bar.sync $0, $1;",
"r,r",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
def AllReduce(reducer, threads, scale, thread_offset, all_threads=None):
"""
AllReduce operation implementing warp/block-level reduction.
Based on tl::AllReduce from reduce.h
Args:
reducer: Reducer operator class (SumOp, MaxOp, etc.)
threads: Number of threads participating in reduction
scale: Reduction scale factor
thread_offset: Thread ID offset
all_threads: Total number of threads in block
Returns:
A callable object with run() and run_hopper() methods
"""
class AllReduceInstance:
def __init__(self, reducer, threads, scale, thread_offset: cutlass.Constexpr[int], all_threads: cutlass.Constexpr[int]):
self.reducer = reducer
self.threads = threads
self.scale = scale
self.thread_offset = thread_offset
self.all_threads = all_threads if all_threads is not None else threads
def run(self, x, red_buf: cute.Pointer = None):
"""
Perform all-reduce across threads.
Based on tl::AllReduce<...>::run from reduce.h
"""
offset = self.threads // 2
if offset >= 32:
# Use shared memory for large thread counts
cute.arch.sync_threads()
tidx, _, _ = cute.arch.thread_idx()
cute.make_tensor(red_buf + tidx - self.thread_offset, (1,))[0] = x
cute.arch.sync_threads()
x = self.reducer()(x, cute.make_tensor(red_buf + ((tidx - self.thread_offset) ^ offset), (1,))[0])
else:
# Use warp shuffle for small thread counts
# Use the pre-existing shuffle_sync_op with butterfly (XOR) mode
other = shuffle_sync_op(x, offset, mask=0xFFFFFFFF, mask_and_clamp=0x1F, kind=nvvm.ShflKind.bfly)
x = self.reducer()(x, other)
return (
x
if offset == self.scale
else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run(x, red_buf)
)
def run_hopper(self, x, red_buf: cute.Pointer = None):
"""
Perform all-reduce on Hopper architecture using bar.sync.
Based on tl::AllReduce<...>::run_hopper from reduce.h
"""
offset = self.threads // 2
tidx, _, _ = cute.arch.thread_idx()
if offset >= 32:
# Use inlined asm for bar.sync to avoid instruction reordering
bar_sync_ptx(1, self.all_threads)
cute.make_tensor(red_buf + tidx - self.thread_offset, (1,))[0] = x
bar_sync_ptx(2, self.all_threads)
x = self.reducer()(x, cute.make_tensor(red_buf + ((tidx - self.thread_offset) ^ offset), (1,))[0])
else:
# Use warp shuffle for small thread counts
# Use the pre-existing shuffle_sync_op with butterfly (XOR) mode
other = shuffle_sync_op(x, offset, mask=0xFFFFFFFF, mask_and_clamp=0x1F, kind=nvvm.ShflKind.bfly)
x = self.reducer()(x, other)
return (
x
if offset == self.scale
else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run_hopper(x, red_buf)
)
return AllReduceInstance(reducer, threads, scale, thread_offset, all_threads)
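# Usage sketch (illustrative only, not generated code): assuming a 128-thread
# block, a per-thread value `x`, and a shared-memory scratch buffer `red_buf`
# with at least `threads` elements, a full block reduction would look roughly
# like:
#
#     total = AllReduce(SumOp, threads=128, scale=1, thread_offset=0).run(x, red_buf)
#
# The recursion halves `threads` at each step, exchanging values through shared
# memory while 32 or more lanes participate and switching to butterfly warp
# shuffles below that, until the participating width reaches `scale`.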
import cutlass.cute as cute
from cutlass.cute.typing import Constexpr
from dataclasses import dataclass
@dataclass(frozen=True)
class dim3:
x: int
y: int
z: int
def ThreadIdx() -> dim3:
return dim3(*cute.arch.thread_idx())
def BlockIdx() -> dim3:
return dim3(*cute.arch.block_idx())
def GridDim() -> dim3:
return dim3(*cute.arch.grid_dim())
@cute.jit
def rasterization2DRow(panel_width: Constexpr[int]) -> dim3:
blockIdx = BlockIdx()
gridDim = GridDim()
block_idx = blockIdx.x + blockIdx.y * gridDim.x
grid_size = gridDim.x * gridDim.y
panel_size = panel_width * gridDim.x
panel_offset = block_idx % panel_size
panel_idx = block_idx // panel_size
total_panel = cute.ceil_div(grid_size, panel_size)
stride = panel_width if panel_idx + 1 < total_panel else (grid_size - panel_idx * panel_size) // gridDim.x
col_idx = (gridDim.x - 1 - panel_offset // stride) if (panel_idx & 1 != 0) else (panel_offset // stride)
row_idx = panel_offset % stride + panel_idx * panel_width
return dim3(col_idx, row_idx, blockIdx.z)
@cute.jit
def rasterization2DColumn(panel_width: Constexpr[int]) -> dim3:
blockIdx = BlockIdx()
gridDim = GridDim()
block_idx = blockIdx.x + blockIdx.y * gridDim.x
grid_size = gridDim.x * gridDim.y
panel_size = panel_width * gridDim.y
panel_offset = block_idx % panel_size
panel_idx = block_idx // panel_size
total_panel = cute.ceil_div(grid_size, panel_size)
stride = panel_width if panel_idx + 1 < total_panel else (grid_size - panel_idx * panel_size) // gridDim.y
row_idx = (gridDim.y - 1 - panel_offset // stride) if (panel_idx & 1 != 0) else (panel_offset // stride)
col_idx = panel_offset % stride + panel_idx * panel_width
return dim3(col_idx, row_idx, blockIdx.z)
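# Worked example (illustrative): with panel_width=2 on a 4x4 launch grid,
# rasterization2DColumn groups blocks into panels of 2 columns (panel_size = 8).
# Linear blocks 0..7 map to (col_idx, row_idx) =
#   (0,0) (1,0) (0,1) (1,1) (0,2) (1,2) (0,3) (1,3)
# and blocks 8..15 cover columns 2-3 in reverse row order (serpentine), so
# temporally adjacent CTAs touch nearby tiles, improving L2 reuse.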
@@ -197,7 +197,8 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
     device_mod = tir.transform.Simplify()(device_mod)
     if target.kind.name == "cuda":
-        device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
+        global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda")
+        device_mod = tvm.ffi.get_global_func(global_func)(device_mod, target)
     elif target.kind.name == "hip":
         device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
     else:
@@ -211,7 +212,8 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
     device_mod = tilelang.transform.LowerIntrin()(device_mod)
     device_mod = tir.transform.Simplify()(device_mod)
     if target.kind.name == "cuda":
-        device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(device_mod, target)
+        global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda") + "_without_compile"
+        device_mod = tvm.ffi.get_global_func(global_func)(device_mod, target)
     elif target.kind.name == "hip":
         device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(device_mod, target)
     elif target.kind.name == "c":
...
@@ -49,7 +49,7 @@ _Ret = TypeVar("_Ret")
 def compile(
     func: PrimFunc[_KP, _T] = None,
     out_idx: list[int] | int | None = None,
-    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
+    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
     target: str | Target = "auto",
     target_host: str | Target | None = None,
     verbose: bool = False,
@@ -64,7 +64,7 @@ def compile(
         The TileLang TIR function to compile and wrap.
     out_idx : Union[List[int], int], optional
         Index(es) of the output tensors to return (default: None).
-    execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
+    execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
         Execution backend to use for kernel execution. Use "auto" to pick a sensible
         default per target (cuda->tvm_ffi, metal->torch, others->cython).
     target : Union[str, Target], optional
@@ -118,7 +118,7 @@ def compile(
 def par_compile(
     funcs: Iterable[PrimFunc[_KP, _T]],
     out_idx: list[int] | int | None = None,
-    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
+    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
     target: str | Target = "auto",
     target_host: str | Target | None = None,
     verbose: bool = False,
@@ -135,7 +135,7 @@ def par_compile(
         The TileLang TIR functions to compile and wrap.
     out_idx : Union[List[int], int], optional
         Index(es) of the output tensors to return (default: None).
-    execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
+    execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
         Execution backend to use for kernel execution. Use "auto" to pick a sensible
         default per target (cuda->tvm_ffi, metal->torch, others->cython).
     target : Union[str, Target], optional
@@ -256,7 +256,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
     """
     out_idx: list[int] | int | None
-    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
+    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"]
     target: str | Target
     target_host: str | Target
     verbose: bool
@@ -424,7 +424,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
         return kernel
-ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
+ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"]
 @overload
@@ -473,7 +473,7 @@ def jit( # This is the new public interface
         Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
     target_host : Union[str, Target], optional
         Target host for cross-compilation. Defaults to None.
-    execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
+    execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
         Backend for kernel execution and argument passing. Use "auto" to pick a sensible
         default per target (cuda->tvm_ffi, metal->torch, others->cython).
     verbose : bool, optional
...
@@ -4,3 +4,4 @@ from .ctypes import CtypesKernelAdapter # noqa: F401
 from .cython import CythonKernelAdapter # noqa: F401
 from .nvrtc import NVRTCKernelAdapter # noqa: F401
 from .torch import MetalKernelAdapter # noqa: F401
+from .cutedsl import CuTeDSLKernelAdapter # noqa: F401
"""CuTeDSL Backend for TileLang.
This module provides runtime compilation support using NVIDIA's CuTeDSL API.
"""
__all__ = [
"CuTeDSLKernelAdapter",
"TLCuTeDSLSourceWrapper",
"CuTeDSLLibraryGenerator",
"check_cutedsl_available",
]
from .checks import check_cutedsl_available # noqa: F401
from .adapter import CuTeDSLKernelAdapter # noqa: F401
from .wrapper import TLCuTeDSLSourceWrapper # noqa: F401
from .libgen import CuTeDSLLibraryGenerator # noqa: F401
from __future__ import annotations
import logging
from typing import Any, Callable
import torch
from tvm import tir
from tvm.target import Target
from tilelang import tvm as tvm
from tilelang.engine.param import KernelParam
from tilelang.jit.adapter.wrapper import TLPyWrapper
from tilelang.jit.adapter.cutedsl.checks import check_cutedsl_available
from tilelang.jit.adapter.cutedsl.libgen import CuTeDSLLibraryGenerator
from tilelang.utils.language import retrieve_func_from_module
from tilelang.utils.target import determine_target
from tilelang.jit.adapter.base import BaseKernelAdapter
logger = logging.getLogger(__name__)
class CuTeDSLKernelAdapter(BaseKernelAdapter):
pymodule = None
def __init__(
self,
params: list[KernelParam],
result_idx: list[int],
target: str | Target,
func_or_mod: tir.PrimFunc | tvm.IRModule,
host_mod: tvm.IRModule | None = None,
device_mod: tvm.IRModule | None = None,
host_kernel_source: str | None = None,
device_kernel_source: str | None = None,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None,
):
check_cutedsl_available()
self.params = params
self.result_idx = self._legalize_result_idx(result_idx)
self.host_kernel_source = host_kernel_source
self.device_kernel_source = device_kernel_source
if isinstance(func_or_mod, tir.PrimFunc):
gsym = func_or_mod.attrs.get("global_symbol")
if gsym is None:
raise ValueError("PrimFunc is missing required attr 'global_symbol'")
self.ir_module = tvm.IRModule({gsym: func_or_mod})
else:
self.ir_module = func_or_mod
# Cache parameter information during initialization
self.param_dtypes = [param.torch_dtype() for param in params]
self.param_shapes = []
for param in params:
native_shape = []
for dim in param.shape:
if isinstance(dim, tir.IntImm):
native_shape.append(int(dim))
elif isinstance(dim, tir.Var):
# Keep tir.Var for dynamic dimensions
native_shape.append(dim)
else:
native_shape.append(dim)
self.param_shapes.append(native_shape)
self.dynamic_symbolic_map, self.dynamic_symbolic_order = self._process_dynamic_symbolic()
self.target = Target.canon_target(determine_target(target))
self.verbose = verbose
self.wrapper = TLPyWrapper(self.target)
self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs)
self.wrapper.assign_host_module(host_mod)
self.wrapper.assign_device_module(device_mod)
wrapper_result = self.wrapper.wrap(device_kernel_source)
self.host_func = wrapper_result["host_func"]
self.function_names = wrapper_result["function_names"]
self.tma_cpp_init_code = wrapper_result["tma_cpp_init_code"]
self.tma_lib_name = wrapper_result["tma_lib_name"]
self.launcher_cpp_code = wrapper_result.get("launcher_cpp_code", None)
self.launcher_lib_name = wrapper_result.get("launcher_lib_name", None)
self.lib_generator = CuTeDSLLibraryGenerator(self.target, self.verbose)
self.lib_generator.update_lib_code(self.device_kernel_source)
self.lib_generator.update_host_func(self.host_func)
self.lib_generator.update_tma_cpp_init_code(self.tma_cpp_init_code)
self.lib_generator.update_tma_lib_name(self.tma_lib_name)
self.lib_generator.update_launcher_cpp_code(self.launcher_cpp_code)
self.lib_generator.update_launcher_lib_name(self.launcher_lib_name)
self.lib_generator.assign_compile_flags(compile_flags)
self.lib_generator.compile_lib()
self.lib_generator.load_lib()
self.libpath = self.lib_generator.libpath
self.device_kernel_source = open(self.libpath).read()
self.pymodule = self.lib_generator.pymodule
self._post_init()
@classmethod
def from_database(
cls,
params: list[KernelParam],
result_idx: list[int],
target: str,
func_or_mod: tir.PrimFunc | tvm.IRModule,
host_kernel_source: str,
device_kernel_source: str,
kernel_lib_path: str,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None,
):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
adapter.host_kernel_source = host_kernel_source
adapter.device_kernel_source = device_kernel_source
if isinstance(func_or_mod, tir.PrimFunc):
gsym = func_or_mod.attrs.get("global_symbol")
if gsym is None:
raise ValueError("PrimFunc is missing required attr 'global_symbol'")
adapter.ir_module = tvm.IRModule({gsym: func_or_mod})
else:
adapter.ir_module = func_or_mod
# Cache parameter information during initialization
adapter.param_dtypes = [param.torch_dtype() for param in params]
adapter.param_shapes = []
for param in params:
native_shape = []
for dim in param.shape:
if isinstance(dim, tir.IntImm):
native_shape.append(int(dim))
elif isinstance(dim, tir.Var):
# Keep tir.Var for dynamic dimensions
native_shape.append(dim)
else:
native_shape.append(dim)
adapter.param_shapes.append(native_shape)
adapter.dynamic_symbolic_map, adapter.dynamic_symbolic_order = adapter._process_dynamic_symbolic()
adapter.target = Target.canon_target(determine_target(target))
adapter.verbose = verbose
adapter.lib_generator = CuTeDSLLibraryGenerator(adapter.target, adapter.verbose)
adapter.lib_generator.assign_compile_flags(compile_flags)
adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
adapter.libpath = kernel_lib_path
adapter.kernel_global_source = open(adapter.libpath).read()
adapter.pymodule = adapter.lib_generator.pymodule
adapter._post_init()
return adapter
def _process_dynamic_symbolic(self) -> tuple[dict[tir.Var, tuple[int, int, int]], list[tir.Var]]:
"""Extract information about dynamic symbols from the TIR function.
We follow the same ordering semantics as `TLCUDASourceWrapper.get_dynamic_symbolic_set()`:
1) dynamic symbols in buffer shapes (in prim_func param order)
2) then dynamic symbols in buffer strides
The mapping encodes:
- id=0: shape var -> (0, buffer_param_index, dim_index)
- id=1: stride var -> (1, buffer_param_index, stride_index)
Returns:
(dynamic_symbolic_map, dynamic_symbolic_order)
"""
func = self.prim_func
params = func.params
buffer_map = func.buffer_map
dynamic_symbolic_map: dict[tir.Var, tuple[int, int, int]] = {}
dynamic_symbolic_order: list[tir.Var] = []
def unique_push_back(v: tir.Var, entry: tuple[int, int, int]):
if v in dynamic_symbolic_map:
return
dynamic_symbolic_map[v] = entry
dynamic_symbolic_order.append(v)
# 1) Shapes
for i, param in enumerate(params):
if param not in buffer_map:
continue
buffer = buffer_map[param]
for j, shape in enumerate(buffer.shape):
if isinstance(shape, tir.Var):
unique_push_back(shape, (0, i, j))
# 2) Strides
for i, param in enumerate(params):
if param not in buffer_map:
continue
buffer = buffer_map[param]
if buffer.strides is None:
continue
for j, stride in enumerate(buffer.strides):
if isinstance(stride, tir.Var):
unique_push_back(stride, (1, i, j))
return dynamic_symbolic_map, dynamic_symbolic_order
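# Illustrative example (hypothetical shapes): for a PrimFunc with buffer params
# A: (m, 128) and B: (m, n), the structures built above would be
#     dynamic_symbolic_map   == {m: (0, 0, 0), n: (0, 1, 1)}
#     dynamic_symbolic_order == [m, n]
# i.e. `m` is dimension 0 of param 0 and `n` is dimension 1 of param 1; any
# stride vars would follow with id=1 entries.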
def get_kernel_source(self, kernel_only: bool = True) -> str | None:
"""Get the CUDA kernel source code.
Returns
-------
str | None
The kernel source code, or None if not available
"""
return self.device_kernel_source
def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
"""Low-level function to call the compiled CUDA kernel."""
result = self.pymodule.call(*args, stream=stream)
# After first call, save cubin to cache if needed
self._save_cubin_to_cache_if_needed()
return result
def _save_cubin_to_cache_if_needed(self):
"""Save cubin to cache directory after first execution.
This is called after the first kernel execution to ensure the generated
cubin file is copied to the cache directory for future reuse.
"""
if getattr(self, "_cubin_saved_to_cache", False):
return
self._cubin_saved_to_cache = True
# Check if we have a cache path (set by kernel_cache)
cache_path = getattr(self, "_cache_path", None)
if cache_path is None:
return
import os
import shutil
# Source cubin path (in temp directory)
src_py_path = self.libpath
src_py_stem = os.path.splitext(os.path.basename(src_py_path))[0]
src_dir = os.path.dirname(src_py_path)
src_cubin_path = os.path.join(src_dir, f"{src_py_stem}.cubin")
if not os.path.exists(src_cubin_path):
return
# Destination cubin path (in cache directory)
dst_cubin_path = os.path.join(cache_path, "kernel.cubin")
if os.path.exists(dst_cubin_path):
return
# Copy cubin to cache
try:
shutil.copy2(src_cubin_path, dst_cubin_path)
logger.debug(f"Saved CuTeDSL cubin to cache: {dst_cubin_path}")
except Exception as e:
logger.warning(f"Failed to save cubin to cache: {e}", exc_info=True)
def _wrap_forward_from_prebuild_lib(self, *ins: Any, stream: int | None = None):
"""High-level wrapper for kernel execution.
Handles:
1. Input validation
2. Output tensor allocation
3. Dynamic shape resolution
4. CUDA stream management
Args:
ins: Input arguments (may include scalars and tensors)
stream: Optional CUDA stream for asynchronous execution
Returns:
Single tensor or list of tensors containing the kernel results
"""
if len(ins) + len(self.result_idx) != len(self.params):
raise ValueError(
f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs"
)
# Materialize args in PrimFunc param order (inputs + allocated outputs)
ins_idx = 0
param_values: list[Any] = [None] * len(self.params)
for i in range(len(self.params)):
if i in self.result_idx:
continue
param_values[i] = ins[ins_idx]
ins_idx += 1
first_tensor = next((v for v in param_values if isinstance(v, torch.Tensor)), None)
if first_tensor is None:
raise ValueError("Expected at least one torch.Tensor argument to infer CUDA device")
args: list[Any] = []
# tensor pointers
for i in range(len(self.params)):
if i in self.result_idx:
dtype = self.param_dtypes[i]
shape = []
# Now working with native Python list, no FFI calls needed
for s in self.param_shapes[i]:
if isinstance(s, tir.Var):
ref_id, ref_param_idx, ref_dim_idx = self.dynamic_symbolic_map[s]
ref_val = param_values[ref_param_idx]
if not isinstance(ref_val, torch.Tensor):
raise TypeError(f"Dynamic shape/stride var {s} refers to a non-tensor param at index {ref_param_idx}")
if ref_id == 0:
shape.append(ref_val.shape[ref_dim_idx])
elif ref_id == 1:
# Stride vars are not expected in output shapes, but handle defensively.
shape.append(ref_val.stride()[ref_dim_idx])
else:
raise ValueError(f"Unknown dynamic symbol ref id: {ref_id}")
else: # Already converted to Python int during initialization
shape.append(s)
tensor = torch.empty(*shape, dtype=dtype, device=first_tensor.device)
param_values[i] = tensor
else:
tensor = param_values[i]
args.append(tensor)
# dynamic symbolics
for sym in self.dynamic_symbolic_order:
ref_id, buffer_idx, dim_idx = self.dynamic_symbolic_map[sym]
ref_val = param_values[buffer_idx]
if not isinstance(ref_val, torch.Tensor):
raise TypeError(f"Dynamic symbolic var {sym} refers to a non-tensor param at index {buffer_idx}")
if ref_id == 0:
args.append(ref_val.shape[dim_idx])
elif ref_id == 1:
args.append(ref_val.stride()[dim_idx])
else:
raise ValueError(f"Unknown dynamic symbol ref id: {ref_id}")
# if stream is not None, we need to pass the stream to the library
if stream is None:
if str(self.target).startswith("cuda") and torch.cuda.is_available():
stream = torch.cuda.current_stream().cuda_stream
else:
stream = 0
self._forward_from_prebuild_lib(*args, stream=stream)
if len(self.result_idx) == 1:
return args[self.result_idx[0]]
else:
return [args[i] for i in self.result_idx]
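# Usage sketch (hypothetical tensors A and B): for a PrimFunc with three buffer
# params and result_idx=[2], the torch-facing callable returned by
# _convert_torch_func() behaves roughly like
#
#     fn = adapter._convert_torch_func()
#     C = fn(A, B)   # C is allocated here; dynamic dims resolved from A/B
#     C = fn(A, B, stream=torch.cuda.current_stream().cuda_stream)
#
# Tensor arguments are passed in PrimFunc param order, followed by the values
# of the dynamic symbols in dynamic_symbolic_order.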
def _convert_torch_func(self) -> Callable[..., torch.Tensor | list[torch.Tensor]]:
"""Convert to a PyTorch-compatible function.
Returns
-------
Callable[..., torch.Tensor | list[torch.Tensor]]
A callable function that takes tensors and returns tensor(s)
"""
return self._wrap_forward_from_prebuild_lib
@property
def prim_func(self) -> tir.PrimFunc:
"""Returns the primary TIR function from the IR module."""
return retrieve_func_from_module(self.ir_module)
from __future__ import annotations
import re
from importlib import metadata as _importlib_metadata
from importlib.util import find_spec as _find_spec
import os
_CUTEDSL_PUBLIC_DIST = "nvidia-cutlass-dsl"
_CUTEDSL_MIN_VERSION = (4, 3, 1)
_VERSION_TRIPLE_RE = re.compile(r"(\d+)\.(\d+)\.(\d+)")
def _parse_version_triple(version_str: str) -> tuple[int, int, int] | None:
"""Parse a best-effort (major, minor, patch) triple from a version string.
We intentionally avoid importing heavy/optional version parsers. For our
minimum requirement (>= 4.3.1), a numeric triple comparison is sufficient.
"""
m = _VERSION_TRIPLE_RE.search(version_str)
if not m:
return None
return int(m.group(1)), int(m.group(2)), int(m.group(3))
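# Example (illustrative): _parse_version_triple("4.3.1.dev123") -> (4, 3, 1),
# while a string without a numeric triple (e.g. "unknown") returns None.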
def _min_version_str() -> str:
return ".".join(map(str, _CUTEDSL_MIN_VERSION))
def _requirement_spec() -> str:
return f"{_CUTEDSL_PUBLIC_DIST}>={_min_version_str()}"
def check_cutedsl_available() -> None:
"""Fail fast if the CuTeDSL backend cannot be used in this Python environment.
Policy:
- If the public distribution `nvidia-cutlass-dsl` is installed, require version >= a minimum supported version.
- Regardless of distribution metadata, require that `cutlass.cute` is importable.
This intentionally does not mention or special-case any internal distributions.
"""
# 1) Version gate (only when the public dist metadata is present)
try:
dist_version = _importlib_metadata.version(_CUTEDSL_PUBLIC_DIST)
except _importlib_metadata.PackageNotFoundError:
dist_version = None
except Exception:
# Metadata is best-effort; don't block internal/nonstandard installs here.
dist_version = None
if dist_version is not None:
parsed = _parse_version_triple(dist_version)
if parsed is None or parsed < _CUTEDSL_MIN_VERSION:
req = _requirement_spec()
raise ImportError(
f"CuTeDSL backend requires `{req}`, but found version `{dist_version}`. Please run: `pip install -U '{req}'`."
)
# 2) Capability probe: keep it cheap.
# Importing cutlass/cute can be expensive and defeats our lazy-import design,
# especially on cache hits. We only require that the module is importable.
cutlass_spec = _find_spec("cutlass")
if cutlass_spec is None:
req = _requirement_spec()
raise ImportError(f"CuTeDSL backend requires the CUTLASS Python DSL with CuTe support (install via `pip install -U '{req}'`).")
# Avoid find_spec("cutlass.cute") which can be surprisingly expensive.
# Instead, check for a 'cute' submodule/package under cutlass's search locations.
locs = getattr(cutlass_spec, "submodule_search_locations", None)
has_cute = False
if locs:
for base in locs:
if os.path.isdir(os.path.join(base, "cute")) or os.path.isfile(os.path.join(base, "cute.py")):
has_cute = True
break
if not has_cute:
req = _requirement_spec()
raise ImportError(f"CuTeDSL backend requires the CUTLASS Python DSL with CuTe support (install via `pip install -U '{req}'`).")
"""CuTeDSL Library Generator for TileLang.
This module provides library generation functionality for the CuTeDSL backend.
"""
from __future__ import annotations
import importlib.util
import os
import tempfile
import subprocess
from tvm.target import Target
from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.jit.adapter.utils import is_cutedsl_target
class CuTeDSLLibraryGenerator(LibraryGenerator):
host_func: str | None = None
tma_cpp_init_code: str | None = None
tma_lib_name: str | None = None
launcher_cpp_code: str | None = None
launcher_lib_name: str | None = None
pymodule = None
def __init__(self, target: Target, verbose: bool = False):
super().__init__(target, verbose)
@staticmethod
def import_from_file(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def update_host_func(self, host_func: str):
self.host_func = host_func
def update_tma_cpp_init_code(self, tma_cpp_init_code: str):
self.tma_cpp_init_code = tma_cpp_init_code
def update_tma_lib_name(self, tma_lib_name: str):
self.tma_lib_name = tma_lib_name
def update_launcher_cpp_code(self, launcher_cpp_code: str):
self.launcher_cpp_code = launcher_cpp_code
def update_launcher_lib_name(self, launcher_lib_name: str):
self.launcher_lib_name = launcher_lib_name
def load_lib(self, lib_path: str | None = None):
if lib_path is None:
if self.libpath is None:
raise RuntimeError("CuTeDSLLibraryGenerator.libpath is not set; call compile_lib() first or pass lib_path explicitly.")
lib_path = self.libpath
self.pymodule = self.import_from_file("kernel", lib_path)
def compile_lib(self, timeout: float = None):
if self.host_func is None:
raise RuntimeError("CuTeDSLLibraryGenerator.host_func is not set; call update_host_func() before compile_lib().")
target = self.target
if is_cutedsl_target(target):
# Use a dedicated temp directory per kernel so CuTeDSL artifacts (e.g. kept .cubin)
# never pollute user CWD, and are easy to locate alongside the generated module.
work_dir = tempfile.mkdtemp(prefix="tilelang_cutedsl_")
src_path = os.path.join(work_dir, "kernel.py")
with open(src_path, "w") as f:
# Note: lib_code (containing @cute.kernel definitions) is embedded
# inside host_func's _generate_cubin_if_needed function, so we only
# write host_func here. This ensures cute imports are lazy-loaded.
f.write(self.host_func)
# Compile C++ launcher library if needed
if self.launcher_cpp_code is not None:
with tempfile.NamedTemporaryFile(
mode="w",
suffix=".cpp",
delete=False,
) as launcher_src:
launcher_src.write(self.launcher_cpp_code)
launcher_src_path = launcher_src.name
# Generate launcher lib under the same directory as the source file
launcher_lib_path = os.path.join(os.path.dirname(src_path), self.launcher_lib_name)
# Get TVM FFI compiler flags using tvm_ffi.libinfo API
try:
import tvm_ffi.libinfo
include_paths = tvm_ffi.libinfo.include_paths()
tvm_cxxflags = [f"-I{path}" for path in include_paths]
lib_path = tvm_ffi.libinfo.find_libtvm_ffi()
lib_dir = os.path.dirname(lib_path)
tvm_ldflags = [f"-L{lib_dir}", "-ltvm_ffi"]
except (ImportError, RuntimeError):
# tvm_ffi unavailable or libinfo functions failed
tvm_cxxflags = []
tvm_ldflags = []
# Compile with nvcc (need CUDA driver API)
compile_cmd = [
"nvcc",
"-shared",
"-Xcompiler=-fPIC",
"-lcuda",
*tvm_cxxflags,
*tvm_ldflags,
"-o",
launcher_lib_path,
launcher_src_path,
]
result = subprocess.run(compile_cmd, check=False, capture_output=True, text=True, timeout=timeout)
if result.returncode != 0:
raise RuntimeError(f"Failed to compile C++ launcher: {result.stderr}")
self.launcher_libpath = launcher_lib_path
self.launcher_libname = self.launcher_lib_name
self.srcpath = src_path
self.libpath = src_path
else:
raise ValueError(f"Unsupported target: {target}")
"""CuTeDSL Source Wrapper for TileLang.
This module provides C++ kernel launcher generation for the CuTeDSL backend.
Key features:
- Automatic C++ launcher generation with CUDA Driver API
- TMA descriptors on HOST memory, passed via __grid_constant__ (no device copy needed)
- cuLaunchKernel automatically copies 128-byte CUtensorMap to kernel param space
- Support for single and multiple kernel launches
- Complete cache system integration
"""
from __future__ import annotations
from typing import Any, ClassVar
from tvm import IRModule
from tvm.target import Target
from tvm.tir.stmt_functor import post_order_visit
from tilelang import tvm as tvm
from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper
from tilelang.jit.adapter.utils import (
extract_python_func_declaration,
pythonic_expr,
parse_tma_descriptor_args,
)
# =============================================================================
# C++ LAUNCHER TEMPLATES (using named parameters for clarity)
# =============================================================================
# TMA single descriptor initialization template (writes to caller-provided host array)
# No device copy needed - cuLaunchKernel handles __grid_constant__ params automatically
CPP_TMA_DESC_INIT_TEMPLATE = """\
// Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name})
{{
uint64_t globalDim[{rank}] = {{{global_dim_values}}};
uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}};
uint32_t boxDim[{rank}] = {{{box_dim_values}}};
uint32_t elemStrides[{rank}] = {{{elem_stride_values}}};
result = cuTensorMapEncodeTiled(
&tma_descs[{desc_idx}],
static_cast<CUtensorMapDataType>({dtype}),
{rank},
reinterpret_cast<void*>({tensor_name}_ptr),
globalDim,
globalStrides,
boxDim,
elemStrides,
static_cast<CUtensorMapInterleave>({interleave}),
static_cast<CUtensorMapSwizzle>({swizzle}),
static_cast<CUtensorMapL2promotion>({l2_promotion}),
static_cast<CUtensorMapFloatOOBfill>({oob_fill})
);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to encode TMA descriptor {desc_idx}: " << result << "\\n";
return result;
}}
}}
"""
# TMA single im2col descriptor initialization template (writes to caller-provided host array)
# Align field ordering with NVRTC wrapper (cuTensorMapEncodeIm2col signature).
CPP_TMA_IM2COL_DESC_INIT_TEMPLATE = """\
// Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name}) [im2col]
{{
uint64_t globalDim[{rank}] = {{{global_dim_values}}};
uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}};
uint32_t elemStrides[{rank}] = {{{elem_stride_values}}};
int32_t lowerCorner[{rank_minus_two}] = {{{lower_corner_values}}};
int32_t upperCorner[{rank_minus_two}] = {{{upper_corner_values}}};
result = cuTensorMapEncodeIm2col(
&tma_descs[{desc_idx}],
static_cast<CUtensorMapDataType>({dtype}),
{rank},
reinterpret_cast<void*>({tensor_name}_ptr),
globalDim,
globalStrides,
lowerCorner,
upperCorner,
static_cast<uint32_t>({channels_per_pixel}),
static_cast<uint32_t>({pixels_per_column}),
elemStrides,
static_cast<CUtensorMapInterleave>({interleave}),
static_cast<CUtensorMapSwizzle>({swizzle}),
static_cast<CUtensorMapL2promotion>({l2_promotion}),
static_cast<CUtensorMapFloatOOBfill>({oob_fill})
);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to encode TMA im2col descriptor {desc_idx}: " << result << "\\n";
return result;
}}
}}
"""
# TMA initialization function template (writes to caller-provided host array)
# __grid_constant__ allows kernel to receive TMA descriptor by value via param space
CPP_TMA_INIT_FUNC_TEMPLATE = """\
CUresult tma_init(CUtensorMap* tma_descs, {func_args}) {{
// Initialize {num_descs} TMA descriptor(s) in caller-provided host array
// cuLaunchKernel will copy 128-byte CUtensorMap to kernel param space automatically
CUresult result;
{desc_init_code}
return CUDA_SUCCESS;
}}
"""
# Kernel initialization template
CPP_KERNEL_INIT_TEMPLATE = """\
// Find and configure kernel {kernel_idx}: {kernel_name}
result = find_kernel_by_pattern(g_module, "{kernel_name}", &g_kernels[{kernel_idx}]);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to find kernel {kernel_name}: " << result << "\\n";
return result;
}}
if ({smem_size} > 0) {{
result = cuFuncSetAttribute(g_kernels[{kernel_idx}],
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
{smem_size});
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to set smem for {kernel_name}: " << result << "\\n";
return result;
}}
}}
"""
# TMA launch initialization template (host memory mode - uses __grid_constant__)
# Kernel receives TMA descriptor by value: .param .align 128 .b8 xxx_param[128]
CPP_TMA_LAUNCH_INIT_TEMPLATE = """\
// Declare stack-local TMA descriptor array (eliminates concurrency race)
CUtensorMap tma_descs[{num_tma_descs}];
// Initialize TMA descriptors (HOST memory - passed via __grid_constant__)
// NOTE: We intentionally do NOT reuse/cached descriptors across launches.
// Pointer-only reuse is a correctness trap (shape/stride may change with same ptr),
// and correctness beats micro-optimizations.
result = tma_init(tma_descs, {tma_tensor_args});
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to initialize TMA descriptors: " << result << "\\n";
return result;
}}
"""
# Kernel launch template
CPP_KERNEL_LAUNCH_TEMPLATE = """\
// Launch kernel {kernel_idx}: {kernel_name}
{{
void* args[] = {{{kernel_args}}};
result = cuLaunchKernel(
g_kernels[{kernel_idx}],
{grid_x}, {grid_y}, {grid_z},
{block_x}, {block_y}, {block_z},
{smem_size},
stream,
args,
nullptr
);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to launch kernel {kernel_name}: " << result << "\\n";
return result;
}}
}}
"""
# Complete C++ launcher template
CPP_LAUNCHER_TEMPLATE = """\
#include <cuda.h>
#include <cstdint>
#include <iostream>
#include <fstream>
#include <vector>
#include <cstring>
#include <string>
// TVM Headers
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
// Cached module handle
static CUmodule g_module = nullptr;
static bool g_module_initialized = false;
// Cached kernel functions
static CUfunction g_kernels[{num_kernels}] = {{nullptr}};
static bool g_kernels_initialized = false;
// Find kernel by pattern (substring match, prefer base name over _N variants)
CUresult find_kernel_by_pattern(CUmodule module, const char* pattern, CUfunction* out_func) {{
CUresult result;
unsigned int num_funcs = 0;
result = cuModuleGetFunctionCount(&num_funcs, module);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to get function count: " << result << "\\n";
return result;
}}
std::vector<CUfunction> func_list(num_funcs);
result = cuModuleEnumerateFunctions(func_list.data(), num_funcs, module);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to enumerate functions: " << result << "\\n";
return result;
}}
// Collect substring matches, separating base name from _N variants
std::vector<std::pair<std::string, CUfunction>> base_matches; // pattern not followed by _digit
std::vector<std::pair<std::string, CUfunction>> variant_matches; // pattern followed by _digit
size_t pattern_len = std::strlen(pattern);
for (unsigned int i = 0; i < num_funcs; i++) {{
const char* func_name = nullptr;
result = cuFuncGetName(&func_name, func_list[i]);
if (result != CUDA_SUCCESS || func_name == nullptr) {{
std::cerr << "Failed to get function name: " << result << "\\n";
return result;
}}
std::string name_str(func_name);
size_t pos = name_str.find(pattern);
if (pos != std::string::npos) {{
// Found substring match
size_t after_pattern = pos + pattern_len;
// Check what follows the pattern
if (after_pattern < name_str.length() &&
name_str[after_pattern] == '_' &&
after_pattern + 1 < name_str.length() &&
std::isdigit(name_str[after_pattern + 1])) {{
// Pattern followed by _digit (e.g., "main_kernel_1")
variant_matches.push_back({{name_str, func_list[i]}});
}} else {{
// Pattern not followed by _digit (e.g., "main_kernel" itself)
base_matches.push_back({{name_str, func_list[i]}});
}}
}}
}}
// Decision logic: prefer base matches over variant matches
if (!base_matches.empty()) {{
if (base_matches.size() == 1) {{
*out_func = base_matches[0].second;
return CUDA_SUCCESS;
}}
// Multiple base matches - ambiguous
std::cerr << "Error: Pattern '" << pattern << "' matched " << base_matches.size()
<< " base kernels (ambiguous). Matches found:\\n";
for (const auto& match : base_matches) {{
std::cerr << " - " << match.first << "\\n";
}}
std::cerr << "Please use a more specific pattern.\\n";
return CUDA_ERROR_NOT_FOUND;
}}
// No base matches, try variant matches
if (!variant_matches.empty()) {{
if (variant_matches.size() == 1) {{
*out_func = variant_matches[0].second;
return CUDA_SUCCESS;
}}
// Multiple variant matches - ambiguous
std::cerr << "Error: Pattern '" << pattern << "' matched " << variant_matches.size()
<< " variant kernels (ambiguous). Matches found:\\n";
for (const auto& match : variant_matches) {{
std::cerr << " - " << match.first << "\\n";
}}
std::cerr << "Please use a more specific pattern (e.g., '" << pattern << "_1').\\n";
return CUDA_ERROR_NOT_FOUND;
}}
// No matches at all
std::cerr << "Failed to find kernel matching pattern '" << pattern << "'\\n";
return CUDA_ERROR_NOT_FOUND;
}}
// Initialize CUDA module (called once on first launch)
static CUresult tilelang_init_cuda_module(const std::string& cubin_path) {{
if (g_module_initialized) return CUDA_SUCCESS;
CUresult result;
result = cuInit(0);
if (result != CUDA_SUCCESS) return result;
std::ifstream cubin_file(cubin_path.c_str(), std::ios::binary);
if (!cubin_file) {{
std::cerr << "Failed to open cubin file: " << cubin_path << "\\n";
return CUDA_ERROR_FILE_NOT_FOUND;
}}
std::vector<char> cubin_data((std::istreambuf_iterator<char>(cubin_file)),
std::istreambuf_iterator<char>());
cubin_file.close();
if (cubin_data.empty()) {{
std::cerr << "Empty cubin file: " << cubin_path << "\\n";
return CUDA_ERROR_INVALID_IMAGE;
}}
result = cuModuleLoadData(&g_module, cubin_data.data());
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to load CUDA module: " << result << "\\n";
return result;
}}
g_module_initialized = true;
return CUDA_SUCCESS;
}}
// Initialize all kernel functions (called once after module load)
static CUresult tilelang_init_kernels() {{
if (g_kernels_initialized) return CUDA_SUCCESS;
CUresult result;
{kernel_inits}
g_kernels_initialized = true;
return CUDA_SUCCESS;
}}
// TMA descriptor initialization (host-side)
{tma_init_func}
// Main kernel launcher
extern "C" CUresult launch_kernel({launch_func_sig}, uint64_t _stream, tvm::ffi::Bytes cubin_path) {{
CUresult result;
std::string cubin_path_str(reinterpret_cast<const char*>(cubin_path.data()), cubin_path.size());
result = tilelang_init_cuda_module(cubin_path_str);
if (result != CUDA_SUCCESS) return result;
result = tilelang_init_kernels();
if (result != CUDA_SUCCESS) return result;
{get_ptr_code}
CUstream stream = (CUstream)_stream;
{tma_init_in_launch}
{kernel_launches}
return CUDA_SUCCESS;
}}
// Cleanup function
extern "C" CUresult cleanup_module() {{
if (g_module_initialized && g_module != nullptr) {{
cuModuleUnload(g_module);
g_module = nullptr;
g_module_initialized = false;
}}
g_kernels_initialized = false;
return CUDA_SUCCESS;
}}
TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_kernel, launch_kernel);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(cleanup_module, cleanup_module);
"""
# =============================================================================
# PYTHON CUBIN GENERATION TEMPLATES
# =============================================================================
# TMA descriptor atom initialization template
CUBIN_TMA_ATOM_INIT_TEMPLATE = """\
{desc_name} = tl.Gemm_SM90.get_tma_atom(__fake_tensor__, (32, 32))"""
# Kernel launch call template
CUBIN_KERNEL_LAUNCH_TEMPLATE = """\
{function_name}({call_args}).launch(
grid=[{grid_x}, {grid_y}, {grid_z}],
block=[{block_x}, {block_y}, {block_z}],
smem={smem_size},
stream=stream,
)"""
# Fake tensor creation template
CUBIN_FAKE_TENSOR_TEMPLATE = """\
__fake_{arg_name}__ = make_fake_compact_tensor(_DTYPE_MAP[str({arg_name}.dtype)], {arg_name}.shape, stride_order={arg_name}.dim_order()[::-1], assumed_align=16)"""
# Complete cubin generation code template
# {lib_code} contains the @cute.kernel definitions and is embedded here
CUBIN_GEN_CODE_TEMPLATE = """\
{lib_code}
@cute.jit
def kernel_wrapper({wrapper_args}):
{tma_init_code}{kernel_launches}
# Compile kernels to generate cubin
{fake_tensor_code}{fake_tma_tensor_code} __fake_stream__ = make_fake_stream()
# Always generate cubin under a unique staging directory to avoid concurrent
# processes clobbering each other's intermediate artifacts.
_staging_dir = Path(tempfile.mkdtemp(
prefix=Path(__file__).stem + ".cubin.staging.",
dir=_module_dir,
))
try:
_kernel_wrapper = cute.compile(
kernel_wrapper,
{compile_args},
options=f"--enable-tvm-ffi --keep-cubin --dump-dir={{_staging_dir.as_posix()}}",
)
# CuTeDSL generates a long, mangled cubin filename that includes argument/type info,
# e.g. "cutlass_kernel_wrapper_FakeTensor...sm_90a.cubin". We expect exactly one cubin.
_cubin_files = sorted(_staging_dir.glob("*.cubin"), key=lambda p: p.stat().st_mtime)
if len(_cubin_files) != 1:
raise RuntimeError(
f"Expected exactly one .cubin under {{_staging_dir}}, got {{len(_cubin_files)}}: {{_cubin_files}}"
)
os.replace(_cubin_files[0], _cubin_path)
finally:
shutil.rmtree(_staging_dir, ignore_errors=True)"""
# =============================================================================
# PYTHON HOST FUNCTION TEMPLATE
# =============================================================================
PYTHON_HOST_FUNC_TEMPLATE = """\
import os
from pathlib import Path
# Minimal imports for runtime (no cutlass/cute - only needed for cubin generation)
import tvm.runtime as runtime
_cpp_launcher = None
_cpp_launcher_lib = None
_cubin_generated = False
# Pre-compute paths - cubin is stored alongside the launcher .so
# Use module basename to avoid conflicts when multiple kernels run concurrently
# e.g., "/tmp/tmp8liu__ho.py" -> "/tmp/tmp8liu__ho.cubin"
# "kernel.py" (in cache) -> "kernel.cubin"
_module_dir = Path(os.path.dirname(__file__))
_cubin_path = _module_dir / (Path(__file__).stem + ".cubin")
_cubin_path_bytes = _cubin_path.as_posix().encode('utf-8')
_cubin_needs_generation = not _cubin_path.exists()
def _generate_cubin_if_needed({cubin_gen_params}):
\"\"\"Generate cubin file on first call.
All CuTeDSL imports are inside this function to avoid slow
module-level initialization when loading from cache.
\"\"\"
global _cubin_generated, _cubin_path
# Lazy import CuTeDSL only when cubin generation is needed
from cuda.bindings.driver import CUstream
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import make_fake_stream, make_fake_compact_tensor
import tilelang.contrib.cutedsl as tl
# We rely on CuTeDSL's keep-cubin artifact rather than custom extraction.
import tempfile
import shutil
_DTYPE_MAP = {{
"torch.float32": cutlass.Float32,
"torch.float16": cutlass.Float16,
"torch.bfloat16": cutlass.BFloat16,
"torch.float8_e4m3fnuz": cutlass.Float8E4M3FN,
"torch.float8_e4m3fn": cutlass.Float8E4M3FN,
"torch.float8_e5m2": cutlass.Float8E5M2,
"torch.float64": cutlass.Float64,
"torch.int64": cutlass.Int64,
"torch.int32": cutlass.Int32,
"torch.uint32": cutlass.Uint32,
"torch.bool": cutlass.Boolean,
"torch.int8": cutlass.Int8,
"torch.uint8": cutlass.Uint8,
"torch.int16": cutlass.Int16,
"torch.uint16": cutlass.Uint16,
"torch.uchar": cutlass.Uint8,
}}
{cubin_gen_code}
_cubin_generated = True
def _load_cpp_launcher():
\"\"\"Load C++ kernel launcher.\"\"\"
global _cpp_launcher, _cpp_launcher_lib
if _cpp_launcher is not None:
return _cpp_launcher
lib_path = os.path.join(os.path.dirname(__file__), "{launcher_lib_name}")
if not os.path.exists(lib_path):
raise FileNotFoundError(f"Launcher not found: {{lib_path}}")
_cpp_launcher_lib = runtime.load_module(lib_path)
_cpp_launcher = _cpp_launcher_lib["launch_kernel"]
return _cpp_launcher
def call({call_func_params}, stream):
\"\"\"Kernel dispatch function.\"\"\"
global _cubin_path_bytes, _cubin_needs_generation
if _cubin_needs_generation:
_generate_cubin_if_needed({cubin_gen_call_args})
_cubin_needs_generation = False
{arg_prep_code}
launcher = _load_cpp_launcher()
result = launcher({launcher_call_args}, stream, _cubin_path_bytes)
if result != 0:
raise RuntimeError(f"Kernel launch failed with CUDA error {{result}}")
"""
# =============================================================================
# WRAPPER CLASS
# =============================================================================
class TLCuTeDSLSourceWrapper(TLCUDASourceWrapper):
"""Wrapper class for TileLang CuTe DSL backend with C++ launcher.
Generates optimized C++ launcher code that:
- Loads cubin via CUDA Driver API
- Passes TMA descriptors by value (host-side, no device copy)
- Launches kernels with minimal Python overhead
- Supports both single and multiple kernel scenarios
"""
_TYPE_MAP: ClassVar[dict[str, str]] = {
"float32": "cutlass.Float32",
"float16": "cutlass.Float16",
"bfloat16": "cutlass.BFloat16",
"float8_e4m3": "cutlass.Float8E4M3",
"float8_e5m2": "cutlass.Float8E5M2",
"float64": "cutlass.Float64",
"int64": "cutlass.Int64",
"int32": "cutlass.Int32",
"uint32": "cutlass.Uint32",
"bool": "cutlass.Boolean",
"int8": "cutlass.Int8",
"uint8": "cutlass.Uint8",
"int16": "cutlass.Int16",
"uint16": "cutlass.Uint16",
"uchar": "cutlass.Uint8",
}
# C++ launcher code must not depend on cutlass Python types.
# Use plain C/C++ types for expression rendering inside generated .cpp.
_CXX_TYPE_MAP: ClassVar[dict[str, str]] = {
"float32": "float",
"float64": "double",
"int64": "int64_t",
"int32": "int32_t",
"uint32": "uint32_t",
"bool": "bool",
"int8": "int8_t",
"uint8": "uint8_t",
"int16": "int16_t",
"uint16": "uint16_t",
}
_CTYPES_MAP: ClassVar[dict[str, str]] = {
"buffer": "ctypes.c_uint64",
"cutlass.Float32": "ctypes.c_float",
"cutlass.Float16": "ctypes.c_uint16",
"cutlass.Float64": "ctypes.c_double",
"cutlass.Int64": "ctypes.c_int64",
"cutlass.Int32": "ctypes.c_int32",
"cutlass.Uint32": "ctypes.c_uint32",
"cutlass.Int8": "ctypes.c_int8",
"cutlass.Uint8": "ctypes.c_uint8",
"cutlass.Int16": "ctypes.c_int16",
"cutlass.Uint16": "ctypes.c_uint16",
"int": "ctypes.c_int32",
}
_generated_host_func: str | None = None
_launcher_lib_name: str | None = None
def __init__(
self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: IRModule | None = None,
host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None,
):
super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)
# =========================================================================
# Properties
# =========================================================================
@property
def host_func(self):
"""Override parent's host_func to return generated Python code."""
if self._generated_host_func is not None:
return self._generated_host_func
return super().host_func
@host_func.setter
def host_func(self, value):
"""Allow setting generated host function code."""
self._generated_host_func = value
# =========================================================================
# Utility Methods
# =========================================================================
def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str:
"""Convert TVM expression to Python string."""
return pythonic_expr(expr, self._TYPE_MAP, floor_div_op="//")
def _cxx_expr(self, expr: tvm.tir.PrimExpr) -> str:
"""Convert TVM expression to C++ string for generated launcher code."""
return pythonic_expr(expr, self._CXX_TYPE_MAP)
@staticmethod
def _cxx_cast(ctype: str, expr_str: str) -> str:
return f"static_cast<{ctype}>({expr_str})"
def _collect_function_args(self) -> tuple[list[dict], list[str]]:
"""Collect all function arguments from primary function.
Returns:
Tuple of (function_args, buffer_args)
"""
function_args = []
buffer_args = []
for param in self.prim_func.params:
if param in self.prim_func.buffer_map:
buffer = self.prim_func.buffer_map[param]
function_args.append({"name": buffer.data.name, "type": "buffer"})
buffer_args.append(buffer.data.name)
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]})
else:
raise ValueError(f"Parameter {param} not in buffer map")
existing_names = {arg["name"] for arg in function_args}
for dyn_sym in self.get_dynamic_symbolic_set(self.prim_func):
dyn_sym_name, dyn_sym_dtype = dyn_sym if isinstance(dyn_sym, tuple) else (dyn_sym, "int32")
if dyn_sym_name in existing_names:
continue
existing_names.add(dyn_sym_name)
function_args.append({"name": dyn_sym_name, "type": self._TYPE_MAP.get(dyn_sym_dtype, "int")})
return function_args, buffer_args
@staticmethod
def _extract_func_call_args(
declaration: str,
function_args: list[dict],
function_params: list,
desc_name_map: dict[str, str] | None = None,
desc_name_var_map: dict[str, tvm.tir.Var] | None = None,
) -> list[tuple[str, str]]:
"""Extract function call arguments from Python function declaration."""
def maybe_desc(name: str | tuple[str, str], param_names: list[str], i: int):
name_str = name if isinstance(name, str) else name[0]
param = param_names[i]
if not (param == name_str + "_desc" or param.startswith(name_str + "_desc_")):
return False
if desc_name_map is not None:
desc_name_map[param] = name_str
return True
def extract_param_names_ast(decl: str) -> list[str] | None:
"""Extract parameter names using AST parsing."""
import ast
import warnings
try:
# Build a syntactically valid function by adding a body
func_stub = decl.rstrip()
if not func_stub.endswith(":"):
func_stub += ":"
func_stub += "\n pass"
# Parse and locate the FunctionDef
tree = ast.parse(func_stub)
func_def = None
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
func_def = node
break
if func_def is None:
return None
# Extract parameter names, skipping 'self'
param_names = []
for arg in func_def.args.args:
if arg.arg != "self":
param_names.append(arg.arg)
return param_names
except Exception as e:
warnings.warn(f"AST parsing failed for function declaration, falling back to split-based parsing: {e}", stacklevel=2)
return None
def extract_param_names_split(decl: str) -> list[str]:
"""Fallback: extract parameter names using naive split-based parsing."""
paren_start = decl.find("(")
paren_end = decl.rfind(")")
if paren_start == -1 or paren_end == -1:
return []
params_str = decl[paren_start + 1 : paren_end].strip()
if not params_str:
return []
param_parts = params_str.split(",")
param_names = []
for param in param_parts:
param = param.strip()
if not param or param == "self":
continue
if ":" in param:
param_name = param.split(":")[0].strip()
else:
param_name = param.strip()
param_names.append(param_name)
return param_names
# Try AST-based extraction first, fallback to split-based
param_names = extract_param_names_ast(declaration)
if param_names is None:
param_names = extract_param_names_split(declaration)
call_args = []
for i, param_name in enumerate(param_names):
for arg in function_args:
if arg["name"] == param_name:
call_args.append((param_name, arg["type"]))
elif maybe_desc(arg["name"], param_names, i):
call_args.append((param_name, "None"))
if desc_name_var_map is not None and function_params is not None:
assert len(call_args) <= len(function_params)
desc_name_var_map[param_name] = function_params[len(call_args) - 1]
return call_args
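# Illustrative example (the declaration is hypothetical): with function_args
# containing buffer "A" and scalar "m", a declaration such as
#   def main_kernel(A_desc: cute.Tensor, m: cutlass.Int32):
# resolves to call_args = [("A_desc", "None"), ("m", <scalar type>)], and
# desc_name_map records {"A_desc": "A"} so the descriptor can be traced back to
# its source buffer.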
@staticmethod
def _filter_non_descriptor_args(
call_args: list[tuple[str, str]], desc_names: list[str], tma_tensors: list[str]
) -> list[tuple[str, str]]:
"""Filter out descriptor arguments."""
filtered = []
for arg_name, arg_type in call_args:
if "desc" in arg_name and arg_name in desc_names:
continue
if arg_name in tma_tensors:
continue
filtered.append((arg_name, arg_type))
return filtered
# =========================================================================
# TMA Descriptor Code Generation
# =========================================================================
def _generate_tma_desc_init(self, desc_name: str, desc_idx: int, tensor_name: str, info: dict) -> str:
"""Generate single TMA descriptor initialization code."""
if info.get("is_img2col", False):
rank = info["tensor_rank"]
return CPP_TMA_IM2COL_DESC_INIT_TEMPLATE.format(
desc_idx=desc_idx,
desc_name=desc_name,
tensor_name=tensor_name,
rank=rank,
stride_rank=rank - 1,
rank_minus_two=rank - 2,
global_dim_values=", ".join(self._cxx_cast("uint64_t", self._cxx_expr(x)) for x in info["global_dim"]),
global_stride_values=", ".join(self._cxx_cast("uint64_t", self._cxx_expr(x)) for x in info["global_stride"][1:]),
elem_stride_values=", ".join(self._cxx_cast("uint32_t", self._cxx_expr(x)) for x in info["element_strides"]),
lower_corner_values=", ".join(self._cxx_cast("int32_t", self._cxx_expr(x)) for x in info["lower_corner"]),
upper_corner_values=", ".join(self._cxx_cast("int32_t", self._cxx_expr(x)) for x in info["upper_corner"]),
# Match NVRTC wrapper naming: channelsPerPixel then pixelsPerColumn
channels_per_pixel=info["smem_box_channel"],
pixels_per_column=info["smem_box_pixel"],
dtype=info["dtype"],
interleave=info["interleave"],
swizzle=info["swizzle"],
l2_promotion=info["l2Promotion"],
oob_fill=info["oobFill"],
)
return CPP_TMA_DESC_INIT_TEMPLATE.format(
desc_idx=desc_idx,
desc_name=desc_name,
tensor_name=tensor_name,
rank=info["tensor_rank"],
global_dim_values=", ".join(self._cxx_cast("uint64_t", self._cxx_expr(x)) for x in info["global_dim"]),
stride_rank=info["tensor_rank"] - 1,
global_stride_values=", ".join(self._cxx_cast("uint64_t", self._cxx_expr(x)) for x in info["global_stride"][1:]),
box_dim_values=", ".join(self._cxx_cast("uint32_t", self._cxx_expr(x)) for x in info["box_dim"]),
elem_stride_values=", ".join(self._cxx_cast("uint32_t", self._cxx_expr(x)) for x in info["element_strides"]),
dtype=info["dtype"],
interleave=info["interleave"],
swizzle=info["swizzle"],
l2_promotion=info["l2Promotion"],
oob_fill=info["oobFill"],
)
def _generate_tma_init_func(
self,
desc_names: list[str],
tensor_args: list[str],
tensor_arg_map: dict[str, tuple[str, int]],
scalar_args: list[dict[str, str]],
) -> str:
"""Generate TMA init function code (creates descriptors in caller-provided host array).
TMA descriptors are stored in stack-local tma_descs[] array in launch_kernel.
cuLaunchKernel automatically handles __grid_constant__ params.
"""
if not desc_names:
return ""
func_args_parts = [f"uint64_t {arg}_ptr" for arg in tensor_args]
for arg in scalar_args:
if arg["type"] in ["int", "cutlass.Int32"]:
func_args_parts.append(f"int32_t {arg['name']}")
elif arg["type"] in ["float", "cutlass.Float32"]:
func_args_parts.append(f"float {arg['name']}")
else:
# Default to int32_t for scalars used in shape/stride math
func_args_parts.append(f"int32_t {arg['name']}")
func_args = ", ".join(func_args_parts)
num_descs = len(desc_names)
desc_inits = []
for idx, desc_name in enumerate(desc_names):
info = self.tma_desc_info[desc_name]
tensor_name, _ = tensor_arg_map[desc_name]
desc_inits.append(self._generate_tma_desc_init(desc_name, idx, tensor_name, info))
return CPP_TMA_INIT_FUNC_TEMPLATE.format(
func_args=func_args,
num_descs=num_descs,
desc_init_code="\n".join(desc_inits),
)
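# Rough shape of the output (illustrative only; the exact C++ body comes from
# CPP_TMA_INIT_FUNC_TEMPLATE defined elsewhere): for tensor_args ["A"] and a
# scalar m, the generated init code takes (uint64_t A_ptr, int32_t m) and fills
# the caller's host-side tma_descs[] array with one CUtensorMap per entry in
# desc_names.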
def _generate_tma_launch_init(
self, desc_names: list[str], tma_tensors: list[str], scalar_args: list[dict[str, str]], num_tma_descs: int
) -> str:
"""Generate TMA initialization code for launch function (host memory mode).
TMA descriptors stay on host. cuLaunchKernel copies them to param space
when kernel uses __grid_constant__ CUtensorMap parameter.
"""
if not desc_names:
return ""
# Generate tma_init call args (no device_ptr needed)
call_args_parts = [f"{arg}_ptr" for arg in tma_tensors] + [arg["name"] for arg in scalar_args]
tma_tensor_args = ", ".join(call_args_parts)
return CPP_TMA_LAUNCH_INIT_TEMPLATE.format(
num_tma_descs=num_tma_descs,
tma_tensor_args=tma_tensor_args,
)
# =========================================================================
# Kernel Code Generation
# =========================================================================
def _generate_kernel_init(self, kernel_idx: int, kernel_name: str, smem_size: int) -> str:
"""Generate kernel initialization code."""
return CPP_KERNEL_INIT_TEMPLATE.format(
kernel_idx=kernel_idx,
kernel_name=kernel_name,
smem_size=smem_size,
)
def _generate_kernel_launch(self, kernel_meta: dict, kernel_idx: int, all_desc_names: list[str]) -> str:
"""Generate single kernel launch code.
For __grid_constant__ CUtensorMap params:
- Pass CUtensorMap* directly (not CUtensorMap**)
- cuLaunchKernel copies 128 bytes to kernel param space
"""
call_args = kernel_meta["call_args"]
desc_names = kernel_meta["desc_names"]
function_info = kernel_meta["function_info"]
# Build kernel args
kernel_args = []
for arg_name, arg_type in call_args:
if "desc" in arg_name and arg_name in desc_names:
# For __grid_constant__ CUtensorMap: pass host pointer directly
# cuLaunchKernel will copy 128-byte CUtensorMap to param space
desc_idx = all_desc_names.index(arg_name)
kernel_args.append(f"&tma_descs[{desc_idx}]")
elif arg_type == "buffer":
kernel_args.append(f"&{arg_name}_ptr")
else:
kernel_args.append(f"&{arg_name}")
grid = function_info["grid_info"]
block = function_info["block_info"]
smem_size = function_info["dynamic_smem_buf"] or 0
return CPP_KERNEL_LAUNCH_TEMPLATE.format(
kernel_idx=kernel_idx,
kernel_name=kernel_meta["function_name"],
kernel_args=", ".join(kernel_args),
grid_x=self._cxx_expr(grid[0]),
grid_y=self._cxx_expr(grid[1]),
grid_z=self._cxx_expr(grid[2]),
block_x=self._cxx_expr(block[0]),
block_y=self._cxx_expr(block[1]),
block_z=self._cxx_expr(block[2]),
smem_size=smem_size,
)
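# Example of the argument packing above (names are illustrative): with
# call_args = [("A_desc", "None"), ("B", "buffer"), ("m", "cutlass.Int32")]
# and A_desc stored at index 0 of all_desc_names, kernel_args becomes
#   ["&tma_descs[0]", "&B_ptr", "&m"]
# i.e. descriptors are passed by host pointer, buffers via their *_ptr locals,
# and scalars by address, matching cuLaunchKernel's void** argument model.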
# =========================================================================
# C++ Launcher Generation
# =========================================================================
def _generate_cpp_launcher(
self,
kernel_metadata_list: list[dict],
function_args: list[dict],
all_tma_tensors: list[str],
all_desc_names: list[str],
tensor_arg_map: dict[str, tuple[str, int]],
) -> str:
"""Generate complete C++ launcher code using templates.
TMA descriptors are stored on HOST memory in stack-local tma_descs[] array.
cuLaunchKernel automatically copies 128-byte CUtensorMap to kernel param space
when kernel uses __grid_constant__ parameter.
"""
num_kernels = len(kernel_metadata_list)
num_tma_descs = max(len(all_desc_names), 1) # At least 1 to avoid zero-size array
# Generate kernel inits
kernel_inits = "\n".join(
self._generate_kernel_init(idx, km["function_name"], km["function_info"]["dynamic_smem_buf"] or 0)
for idx, km in enumerate(kernel_metadata_list)
)
# Generate TMA init function
scalar_args = [arg for arg in function_args if arg["type"] != "buffer"]
tma_init_func = self._generate_tma_init_func(all_desc_names, all_tma_tensors, tensor_arg_map, scalar_args)
# Generate launch function signature and get_ptr code
func_sig_parts = []
get_ptr_code = ""
for arg in function_args:
if arg["type"] == "buffer":
func_sig_parts.append(f"tvm::ffi::TensorView {arg['name']}")
get_ptr_code += f" uint64_t {arg['name']}_ptr = reinterpret_cast<uint64_t>({arg['name']}.data_ptr());\n"
elif arg["type"] in ["int", "cutlass.Int32"]:
func_sig_parts.append(f"int32_t {arg['name']}")
elif arg["type"] in ["float", "cutlass.Float32"]:
func_sig_parts.append(f"float {arg['name']}")
else:
func_sig_parts.append(f"int32_t {arg['name']}")
# Generate TMA init in launch
tma_init_in_launch = self._generate_tma_launch_init(all_desc_names, all_tma_tensors, scalar_args, num_tma_descs)
# Generate kernel launches
kernel_launches = "\n".join(self._generate_kernel_launch(km, idx, all_desc_names) for idx, km in enumerate(kernel_metadata_list))
return CPP_LAUNCHER_TEMPLATE.format(
num_kernels=num_kernels,
num_tma_descs=num_tma_descs,
kernel_inits=kernel_inits,
tma_init_func=tma_init_func,
launch_func_sig=", ".join(func_sig_parts),
get_ptr_code=get_ptr_code,
tma_init_in_launch=tma_init_in_launch,
kernel_launches=kernel_launches,
)
# =========================================================================
# Python Wrapper Generation
# =========================================================================
def _generate_cubin_gen_code(
self,
kernel_metadata_list: list[dict],
buffer_args: list[str],
all_desc_names: list[str],
lib_code: str = "",
) -> str:
"""Generate cubin generation code for Python wrapper using templates.
Args:
lib_code: The CuTeDSL kernel definitions (@cute.kernel decorated functions).
This will be embedded inside _generate_cubin_if_needed to enable
lazy loading of cutlass/cute modules.
"""
# Build unified wrapper parameters
wrapper_params_union = []
for kernel_meta in kernel_metadata_list:
for arg_name, _ in kernel_meta["call_args"]:
if arg_name not in wrapper_params_union:
wrapper_params_union.append(arg_name)
# Build inner args for cute.compile
inner_args = []
fake_inner_args = []
for arg_name in wrapper_params_union:
if arg_name in buffer_args:
inner_args.append(f"{arg_name}_")
fake_inner_args.append(f"__fake_{arg_name}__")
elif arg_name in all_desc_names:
continue
else:
inner_args.append(arg_name)
fake_inner_args.append(arg_name)
if all_desc_names:
inner_args.append("__fake_tensor__")
fake_inner_args.append("__fake_tensor__")
fake_inner_args.append("__fake_stream__")
# Generate TMA init code
tma_init_code = ""
if all_desc_names:
tma_init_lines = [" # Create dummy TMA atoms for compilation"]
tma_init_lines.extend(CUBIN_TMA_ATOM_INIT_TEMPLATE.format(desc_name=desc_name) for desc_name in all_desc_names)
tma_init_code = "\n".join(tma_init_lines) + "\n"
# Generate kernel launch calls
kernel_launches = "\n".join(
CUBIN_KERNEL_LAUNCH_TEMPLATE.format(
function_name=km["function_name"],
call_args=", ".join(arg[0] if arg[0] not in buffer_args else f"{arg[0]}_" for arg in km["call_args"]),
grid_x=self._pythonic_expr(km["function_info"]["grid_info"][0]),
grid_y=self._pythonic_expr(km["function_info"]["grid_info"][1]),
grid_z=self._pythonic_expr(km["function_info"]["grid_info"][2]),
block_x=self._pythonic_expr(km["function_info"]["block_info"][0]),
block_y=self._pythonic_expr(km["function_info"]["block_info"][1]),
block_z=self._pythonic_expr(km["function_info"]["block_info"][2]),
smem_size=km["function_info"]["dynamic_smem_buf"] or 0,
)
for km in kernel_metadata_list
)
# Generate fake tensor creation code
# IMPORTANT: Generate fake tensors based on the *union* of parameters actually
# passed to cute.compile (wrapper_params_union).
#
# In multi-kernel cases, a tensor may appear both as a TMA descriptor
# (e.g. Output_partial_desc) for one kernel and as a plain tensor argument
# (e.g. Output_partial_) for another kernel. Skipping fake tensor creation
# just because a matching "{arg}_desc" exists is a correctness bug and
# results in undefined names like "__fake_Output_partial__".
fake_tensor_code = "\n".join(
CUBIN_FAKE_TENSOR_TEMPLATE.format(arg_name=arg_name) for arg_name in wrapper_params_union if arg_name in buffer_args
)
if fake_tensor_code:
fake_tensor_code += "\n"
# Generate fake TMA tensor code
fake_tma_tensor_code = ""
if all_desc_names:
fake_tma_tensor_code = (
" __fake_tensor__ = make_fake_compact_tensor(cutlass.Int32, (32, 32), stride_order=(1, 0), assumed_align=16)\n"
)
# Indent lib_code to be inside the function
indented_lib_code = "\n".join(" " + line if line.strip() else line for line in lib_code.split("\n")) if lib_code else ""
return CUBIN_GEN_CODE_TEMPLATE.format(
lib_code=indented_lib_code,
wrapper_args=", ".join(inner_args + ["stream: CUstream"]),
tma_init_code=tma_init_code,
kernel_launches=kernel_launches,
fake_tensor_code=fake_tensor_code,
fake_tma_tensor_code=fake_tma_tensor_code,
compile_args=", ".join(fake_inner_args),
primary_name=kernel_metadata_list[0]["function_name"],
)
def _generate_python_wrapper(
self,
function_args: list[dict],
cubin_gen_code: str,
cubin_gen_params: str,
) -> str:
"""Generate Python wrapper code."""
# Build function parameters
call_func_params = ", ".join(arg["name"] for arg in function_args)
launcher_call_args = ", ".join(arg["name"] for arg in function_args)
return PYTHON_HOST_FUNC_TEMPLATE.format(
cubin_gen_params=cubin_gen_params,
cubin_gen_code=cubin_gen_code,
launcher_lib_name=self._launcher_lib_name,
call_func_params=call_func_params,
cubin_gen_call_args=cubin_gen_params,
arg_prep_code="",
launcher_call_args=launcher_call_args,
)
# =========================================================================
# TMA Descriptor Processing
# =========================================================================
def _process_tma_descriptors(self, desc_names: list[str]) -> tuple[list[str], dict[str, tuple[str, int]]]:
"""Process TMA descriptors and return tensor args and mapping.
Returns:
Tuple of (tensor_args, tensor_arg_map)
"""
if not hasattr(self, "tma_desc_info") or not desc_names:
return [], {}
tensor_args = []
tensor_arg_map = {}
for desc_name in desc_names:
info = self.tma_desc_info[desc_name]
# Extract the base buffer variable name (must be a Var, not arbitrary expression)
global_addr = info["globalAddress"]
if not isinstance(global_addr, tvm.tir.Var):
raise ValueError(f"TMA globalAddress must be a buffer Var, got {type(global_addr)}: {global_addr}")
tensor_name = global_addr.name
if tensor_name not in tensor_args:
tensor_args.append(tensor_name)
tensor_arg_map[desc_name] = (tensor_name, len(tensor_args) - 1)
else:
tensor_arg_map[desc_name] = (tensor_name, tensor_args.index(tensor_name))
return tensor_args, tensor_arg_map
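# Sketch (hypothetical descriptors): if A_desc and A_desc_1 both read buffer A
# while B_desc reads buffer B, the result is
#   tensor_args    = ["A", "B"]
#   tensor_arg_map = {"A_desc": ("A", 0), "A_desc_1": ("A", 0), "B_desc": ("B", 1)}
# so each underlying buffer pointer is passed to the launcher only once, even
# when several TMA descriptors are built on top of it.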
def generate_tma_descriptor_args(
self,
desc_name_map: dict[str, str],
desc_name_var_map: dict[str, tvm.tir.Var],
tma_desc_code_map: dict[str, str],
) -> list[str]:
"""Generate TMA descriptor information for C++ code generation.
Returns:
List of descriptor variable names in the order they were processed.
"""
if self.tma_descriptor_args is None:
return []
if not hasattr(self, "tma_desc_info"):
self.tma_desc_info = {}
parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr)
desc_names_ordered = []
for params in parsed_params:
handle_name = params.handle_name
if handle_name in tma_desc_code_map:
continue
desc_var = desc_name_var_map[handle_name]
args = self.tma_descriptor_args[desc_var]
_, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
tensor_rank = int(tensor_rank)
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank : 2 * tensor_rank]
if not params.is_img2col:
box_dim = remaining_args[2 * tensor_rank : 3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank : 4 * tensor_rank]
self.tma_desc_info[handle_name] = {
"desc_var": desc_var,
"is_img2col": False,
"dtype": params.dtype,
"tensor_rank": params.tensor_rank,
"globalAddress": params.global_address,
"global_dim": global_dim,
"global_stride": global_stride,
"box_dim": box_dim,
"element_strides": element_strides,
"interleave": params.interleave,
"swizzle": params.swizzle,
"l2Promotion": params.l2_promotion,
"oobFill": params.oob_fill,
}
else:
element_strides = remaining_args[2 * tensor_rank : 3 * tensor_rank]
self.tma_desc_info[handle_name] = {
"desc_var": desc_var,
"is_img2col": True,
"dtype": params.dtype,
"tensor_rank": params.tensor_rank,
"globalAddress": params.global_address,
"global_dim": global_dim,
"global_stride": global_stride,
"element_strides": element_strides,
"lower_corner": params.lower_corner,
"upper_corner": params.upper_corner,
"smem_box_channel": params.smem_box_channel,
"smem_box_pixel": params.smem_box_pixel,
"interleave": params.interleave,
"swizzle": params.swizzle,
"l2Promotion": params.l2_promotion,
"oobFill": params.oob_fill,
}
tma_desc_code_map[handle_name] = ""
desc_names_ordered.append(handle_name)
return desc_names_ordered
# =========================================================================
# Main Entry Points
# =========================================================================
def create_dispatch_func(self, code, function_informations):
"""Create dispatch function - always use C++ launcher."""
return self.create_dispatch_func_cpp_launcher(code, function_informations)
def create_dispatch_func_cpp_launcher(self, code, function_informations):
"""Create dispatch function using C++ launcher."""
function_args, buffer_args = self._collect_function_args()
# Process each kernel and collect metadata
kernel_metadata = []
all_desc_names_union = []
all_tma_tensors_union = []
for function_name, function_info in function_informations.items():
declaration = extract_python_func_declaration(code, function_name)
desc_name_map: dict[str, str] = {}
desc_name_var_map: dict[str, tvm.tir.Var] = {}
call_args = self._extract_func_call_args(
declaration,
function_args,
function_info["function_params"],
desc_name_map,
desc_name_var_map,
)
tma_desc_code_map = {}
desc_names = self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map, tma_desc_code_map)
tma_tensor_args, _ = self._process_tma_descriptors(desc_names)
kernel_metadata.append(
{
"function_name": function_name,
"function_info": function_info,
"call_args": call_args,
"desc_names": desc_names,
"tma_tensor_args": tma_tensor_args,
"desc_name_map": desc_name_map,
}
)
for desc in desc_names:
if desc not in all_desc_names_union:
all_desc_names_union.append(desc)
for t in tma_tensor_args:
if t not in all_tma_tensors_union:
all_tma_tensors_union.append(t)
# Process all TMA descriptors
all_tma_tensors, tensor_arg_map = self._process_tma_descriptors(all_desc_names_union)
# Generate C++ launcher
launcher_cpp_code = self._generate_cpp_launcher(
kernel_metadata, function_args, all_tma_tensors, all_desc_names_union, tensor_arg_map
)
self.launcher_cpp_code = launcher_cpp_code
# Use a deterministic name so that:
# 1) the generated kernel.py can always locate the launcher in the same directory
# 2) KernelCache can store it under a stable filename
self._launcher_lib_name = "launcher_lib.so"
self.launcher_lib_name = self._launcher_lib_name
# Generate cubin generation code (includes lib_code with @cute.kernel definitions)
cubin_gen_code = self._generate_cubin_gen_code(
kernel_metadata, buffer_args, all_desc_names_union, lib_code=getattr(self, "lib_code", "")
)
# Generate Python wrapper
buffer_names = [arg["name"] for arg in function_args if arg["type"] == "buffer"]
# Cubin generation may reference scalar args (e.g., dynamic symbols like m/n/k)
# inside `kernel_wrapper` and `cute.compile(...)`. They must be visible in
# `_generate_cubin_if_needed(...)` scope, so include them in its signature.
scalar_names = [arg["name"] for arg in function_args if arg["type"] != "buffer"]
cubin_gen_params = ", ".join(buffer_names + scalar_names)
python_wrapper = self._generate_python_wrapper(function_args, cubin_gen_code, cubin_gen_params)
return python_wrapper
def get_launcher_cpp_code(self) -> str:
"""Get the generated C++ launcher code."""
return getattr(self, "launcher_cpp_code", "")
def update_lib_code(self, code: str):
"""Update the library code with the given code string."""
self.lib_code = code
function_informations = {}
for function_name in self.function_names:
if (function_name not in self.block_info) or (function_name not in self.grid_info):
continue
assert function_name in self.device_mod, f"Function {function_name} not found in device module"
device_func = self.device_mod[function_name]
kernel_params_cnt = len(device_func.params)
function_params: list | None = None
def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
nonlocal function_params
if isinstance(node, tvm.tir.Call):
if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
return
args = node.args
if not args or args[0] != fn:
return
if len(args) < 1 + param_cnt:
raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters")
function_params = args[1 : 1 + param_cnt]
post_order_visit(self.host_func.body, visitor)
assert function_params is not None, "function_params should not be None"
function_informations[function_name] = {
"function_name": function_name,
"block_info": self.block_info[function_name],
"grid_info": self.grid_info[function_name],
"dynamic_smem_buf": self.dynamic_smem_buf[function_name],
"function_params": function_params,
}
self.host_func = self.create_dispatch_func(code, function_informations)
return self.lib_code
@@ -76,7 +76,9 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
self.wrapper.assign_pass_configs(pass_configs)
self.wrapper.assign_host_module(host_mod)
self.wrapper.assign_device_module(device_mod)
- self.host_func, self.function_names = self.wrapper.wrap(device_kernel_source)
+ wrapper_result = self.wrapper.wrap(device_kernel_source)
+ self.host_func = wrapper_result["host_func"]
+ self.function_names = wrapper_result["function_names"]
self.lib_generator = NVRTCLibraryGenerator(self.target, self.verbose)
self.lib_generator.update_lib_code(self.device_kernel_source)
...
@@ -273,7 +273,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
Casts are noise in generated Python code - Python is dynamically typed.
"""
- return pythonic_expr(expr, self._TYPE_MAP, ignore_cast=True)
+ return pythonic_expr(expr, self._TYPE_MAP, ignore_cast=True, floor_div_op="//")
def create_dispatch_func(self, code, function_informations):
"""Generate Python dispatch function that launches multiple CUDA kernels.
...
@@ -38,6 +38,53 @@ def match_declare_kernel(source: str, annotation: str = "__global__") -> int:
raise ValueError("No global kernel found in the source code")
+ def match_declare_kernel_cutedsl(source: str, annotation: str = "@cute.kernel") -> int:
+ # Match decorator followed by function definition across lines
+ # \s+ allows any whitespace including newlines between decorator and def
+ pattern = r"@cute\.kernel\s+def\s+(\w+)"
+ matched = re.search(pattern, source, re.MULTILINE)
+ if matched:
+ # Find the position of the opening parenthesis after the function name
+ # matched.start(1) gives position of function name
+ func_name_pos = matched.start(1)
+ # Find the '(' after function name
+ paren_pos = source.find("(", func_name_pos)
+ if paren_pos != -1:
+ return paren_pos
+ raise ValueError("No global kernel found in the source code")
+ def extract_python_func_declaration(source: str, func_name: str) -> str:
+ """Extract the full Python function declaration from decorator to colon.
+ Args:
+ source: Source code containing the function
+ func_name: Name of the function to extract (can include '(' suffix)
+ Returns:
+ The function declaration from 'def' to ':', including parameters
+ Example:
+ For code:
+ @cute.kernel
+ def kernel(arg1: cute.Tensor, arg2: int):
+ ...
+ Returns: "def kernel(arg1: cute.Tensor, arg2: int)"
+ """
+ # Remove '(' suffix if present
+ if func_name.endswith("("):
+ func_name = func_name[:-1]
+ # Match from def to the closing ) followed by :
+ # This handles multi-line function signatures
+ pattern = rf"def\s+{re.escape(func_name)}\s*\([^)]*\)"
+ matched = re.search(pattern, source, re.DOTALL)
+ if matched:
+ return matched.group(0)
+ raise ValueError(f"No function declaration found for {func_name}")
def match_declare_kernel_cpu(source: str, annotation: str = "int32_t") -> int:
pattern = r"int32_t\s+\w+"
for line in source.split("\n"):
@@ -64,6 +111,10 @@ def is_metal_target(target: Target) -> bool:
return target.kind.name == "metal"
+ def is_cutedsl_target(target: Target) -> bool:
+ return target.kind.name == "cuda" and "cutedsl" in target.keys
def get_annotated_mod(
func_or_mod: tir.PrimFunc | tvm.IRModule,
target: str | Target = "auto",
@@ -102,7 +153,9 @@ def get_annotated_mod(
return dispatch[model_type](mod)
- def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None, ignore_cast: bool = False) -> str:
+ def pythonic_expr(
+     expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None, ignore_cast: bool = False, floor_div_op: str = "/"
+ ) -> str:
"""
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
@@ -110,6 +163,10 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = Non
expr: The TVM PrimExpr to convert.
dtype_map: A dictionary mapping data types to their string representations.
ignore_cast: Whether to ignore the cast operator and return the string representation of the value without the cast.
+ floor_div_op: Operator to use for tvm.tir.FloorDiv. Default '/' preserves prior
+     behavior (suitable for generating C/C++ expressions). For generating
+     Python code where integer division is required (e.g. grid/block),
+     pass '//' explicitly.
Returns:
A string representation of the expression.
"""
@@ -180,7 +237,7 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = Non
):
op_map = {
tvm.tir.Mul: "*",
- tvm.tir.FloorDiv: "/",
+ tvm.tir.FloorDiv: floor_div_op,
tvm.tir.Add: "+",
tvm.tir.Sub: "-",
tvm.tir.FloorMod: "%",
...
@@ -4,8 +4,10 @@ from tilelang import tvm as tvm
from typing import Any
from tvm import IRModule
from tvm.target import Target
from .utils import (
is_metal_target,
+ is_cutedsl_target,
match_declare_kernel,
match_declare_kernel_cpu,
is_cuda_target,
@@ -198,7 +200,9 @@ class TLCUDASourceWrapper:
self.lib_code: str | None = self.update_lib_code(source)
def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str:
- return pythonic_expr(expr, self._TYPE_MAP)
+ # This wrapper generates C/CUDA source. C/C++ integer division uses '/',
+ # and '//' is not a valid operator in C/C++.
+ return pythonic_expr(expr, self._TYPE_MAP, floor_div_op="/")
def _lookup_type(self, dtype: str | Any) -> str:
key = dtype if isinstance(dtype, str) else str(dtype)
@@ -326,9 +330,9 @@ class TLCUDASourceWrapper:
return init_l2_persistent_map
def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
- tma_descripter_init = ""
+ tma_descriptor_init = ""
if self.tma_descriptor_args is None:
- return tma_descripter_init
+ return tma_descriptor_init
# Parse TMA descriptor arguments using the common utility
parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr)
@@ -336,7 +340,7 @@ class TLCUDASourceWrapper:
# Generate C++ code from parsed parameters
for params in parsed_params:
if not params.is_img2col:
- tma_descripter_init += TMA_DESC_INIT_FUNC.format(
+ tma_descriptor_init += TMA_DESC_INIT_FUNC.format(
params.handle_name,
params.dtype,
params.tensor_rank,
@@ -351,7 +355,7 @@ class TLCUDASourceWrapper:
params.oob_fill,
)
else:
- tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format(
+ tma_descriptor_init += TMA_IM2COL_DESC_INIT_FUNC.format(
params.handle_name,
params.dtype,
params.tensor_rank,
@@ -369,7 +373,7 @@ class TLCUDASourceWrapper:
params.oob_fill,
)
- return tma_descripter_init
+ return tma_descriptor_init
def parse_source_information(self):
if self.device_mod is None or self.host_mod is None:
@@ -817,6 +821,9 @@ class TLMetalSourceWrapper:
return self.lib_code
+ # TLCuTeDSLSourceWrapper has been moved to tilelang.jit.adapter.cutedsl.wrapper
class TLWrapper(BaseWrapper):
"""
A wrapper class for the TileLang backend.
@@ -875,9 +882,13 @@ class TLPyWrapper(TLWrapper):
def __init__(self, target: Target):
super().__init__(target)
- def wrap(self, c_source: str):
+ def wrap(self, py_source: str):
# assert self.scheduled_ir_module is not None, "Please assign optimized module first."
- if is_cuda_target(self.target):
+ if is_cutedsl_target(self.target):
+ from tilelang.jit.adapter.cutedsl import TLCuTeDSLSourceWrapper
+ wrapper_class = TLCuTeDSLSourceWrapper
+ elif is_cuda_target(self.target):
from tilelang.jit.adapter.nvrtc import TLNVRTCSourceWrapper
wrapper_class = TLNVRTCSourceWrapper
@@ -885,10 +896,17 @@ class TLPyWrapper(TLWrapper):
raise ValueError(f"Unsupported target for NVRTC backend: {self.target}")
wrapper = wrapper_class(
scheduled_ir_module=self.scheduled_ir_module,
- source=c_source,
+ source=py_source,
target=self.target,
device_mod=self.device_mod,
host_mod=self.host_mod,
pass_configs=self.pass_configs,
)
- return wrapper.host_func, wrapper.function_names
+ return {
+ "host_func": getattr(wrapper, "host_func", None),
+ "function_names": getattr(wrapper, "function_names", None),
+ "tma_cpp_init_code": getattr(wrapper, "tma_cpp_init_code", None),
+ "tma_lib_name": getattr(wrapper, "tma_lib_name", None),
+ "launcher_cpp_code": getattr(wrapper, "launcher_cpp_code", None),
+ "launcher_lib_name": getattr(wrapper, "launcher_lib_name", None),
+ }
@@ -3,6 +3,8 @@ from __future__ import annotations
from collections.abc import Iterable
from tvm.target import Target
+ from tilelang.jit.adapter.utils import is_cutedsl_target
+ from tilelang.env import env as _env
# Canonical names for execution backends used internally
_CANONICAL_MAP = {
@@ -30,7 +32,9 @@ def allowed_backends_for_target(target: Target, *, include_unavailable: bool = T
"""
kind = _target_kind(target)
- if kind == "cuda":
+ if is_cutedsl_target(target):
+ return ["cutedsl"]
+ elif kind == "cuda":
allowed = ["tvm_ffi", "nvrtc", "cython", "ctypes"]
elif kind == "hip":
allowed = ["tvm_ffi", "cython", "ctypes"]
@@ -72,8 +76,26 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str:
allowed_all = allowed_backends_for_target(target, include_unavailable=True)
allowed_avail = allowed_backends_for_target(target, include_unavailable=False)
+ def _require_gemm_v1_for_cutedsl():
+ if not _env.use_gemm_v1():
+ raise ValueError(
+ "CuTeDSL backend requires GEMM v1. Please set environment variable TILELANG_USE_GEMM_V1=1 before importing tilelang."
+ )
+ # Fail fast with a clear error if CuTeDSL dependencies are missing or incompatible.
+ try:
+ from tilelang.jit.adapter.cutedsl.checks import check_cutedsl_available  # lazy
+ check_cutedsl_available()
+ except ImportError as e:
+ # Keep resolve_execution_backend's error semantics (ValueError) while
+ # preserving the actionable ImportError message.
+ raise ValueError(str(e)) from e
# Default selection for auto/None
if req in (None, "auto"):
+ if is_cutedsl_target(target):
+ _require_gemm_v1_for_cutedsl()
+ return "cutedsl"
kind = _target_kind(target)
if kind == "cuda":
choice = "tvm_ffi"
@@ -100,4 +122,8 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str:
f"Try one of: {_format_options(allowed_avail)}."
)
+ # CuTeDSL requires GEMM v1
+ if req == "cutedsl":
+ _require_gemm_v1_for_cutedsl()
return req
@@ -7,7 +7,7 @@ try:
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
- from tilelang.jit.adapter.utils import is_metal_target, is_cuda_target
+ from tilelang.jit.adapter.utils import is_cutedsl_target, is_metal_target, is_cuda_target
from tvm.target import Target
from tvm.tir import PrimFunc
@@ -15,7 +15,14 @@ import tilelang
from tilelang import tvm
from tilelang import env
from tilelang.engine.param import CompiledArtifact, KernelParam
- from tilelang.jit.adapter import BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, TVMFFIKernelAdapter, MetalKernelAdapter
+ from tilelang.jit.adapter import (
+ BaseKernelAdapter,
+ CtypesKernelAdapter,
+ CythonKernelAdapter,
+ CuTeDSLKernelAdapter,
+ TVMFFIKernelAdapter,
+ MetalKernelAdapter,
+ )
from tilelang.profiler import Profiler, TensorSupplyType
from tilelang.utils.target import determine_target
from tilelang.contrib import nvcc as tl_nvcc
@@ -57,7 +64,7 @@ class JITKernel(Generic[_P, _T]):
self,
func: PrimFunc = None,
out_idx: list[int] | int = None,
- execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi",
+ execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi",
target: str | Target = "auto",
target_host: str | Target = None,
verbose: bool = False,
@@ -74,7 +81,7 @@ class JITKernel(Generic[_P, _T]):
The TileLang TIR function to compile and wrap.
out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
- execution_backend : Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
+ execution_backend : Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
Execution backend to use for kernel execution.
target : Union[str, Target], optional
Compilation target, either as a string or a TVM Target object (default: "auto").
@@ -109,6 +116,7 @@ class JITKernel(Generic[_P, _T]):
"cython",
"nvrtc",
"torch",
+ "cutedsl",
], f"Invalid execution backend. {execution_backend}"
if execution_backend == "cython":
from tilelang.contrib.cc import get_cplus_compiler
@@ -316,6 +324,20 @@ class JITKernel(Generic[_P, _T]):
# pass_configs=pass_configs,
# compile_flags=compile_flags,
)
+ elif execution_backend == "cutedsl":
+ assert is_cutedsl_target(target)
+ adapter = CuTeDSLKernelAdapter(
+ params=artifact.params,
+ result_idx=out_idx,
+ target=target,
+ func_or_mod=tilelang_func,
+ host_mod=artifact.host_mod,
+ device_mod=artifact.device_mod,
+ device_kernel_source=artifact.kernel_source,
+ verbose=verbose,
+ pass_configs=pass_configs,
+ compile_flags=compile_flags,
+ )
else:
# Handle invalid backend.
raise ValueError(f"Invalid execution backend: {execution_backend}")
@@ -387,6 +409,18 @@ class JITKernel(Generic[_P, _T]):
pass_configs=pass_configs,
compile_flags=compile_flags,
)
+ elif execution_backend == "cutedsl":
+ adapter = CuTeDSLKernelAdapter.from_database(
+ params=params,
+ result_idx=result_idx,
+ target=target,
+ func_or_mod=func_or_mod,
+ host_kernel_source=host_kernel_source,
+ device_kernel_source=device_kernel_source,
+ kernel_lib_path=kernel_lib_path,
+ pass_configs=pass_configs,
+ compile_flags=compile_flags,
+ )
else:
# Handle invalid backend.
raise ValueError(f"Invalid execution backend: {execution_backend}")
@@ -437,7 +471,7 @@ class JITKernel(Generic[_P, _T]):
str
The source code of the compiled kernel function.
"""
- if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi"}:
+ if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi", "cutedsl"}:
return self.adapter.get_kernel_source(kernel_only=kernel_only)
return self.artifact.kernel_source
@@ -445,7 +479,7 @@ class JITKernel(Generic[_P, _T]):
"""
Returns the source code of the host function.
"""
- if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi"}:
+ if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi", "cutedsl"}:
return self.adapter.get_host_source()
assert self.artifact.host_mod is not None, "host_mod is not available"
return str(self.artifact.host_mod)
...
@@ -15,6 +15,7 @@ SUPPORTED_TARGETS: dict[str, str] = {
"llvm": "LLVM CPU target (accepts standard TVM LLVM options).",
"webgpu": "WebGPU target for browser/WebGPU runtimes.",
"c": "C source backend.",
+ "cutedsl": "CuTe DSL GPU target.",
}
@@ -95,6 +96,14 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj
return_var = "metal"
else:
raise ValueError("No CUDA or HIP or MPS available on this system.")
+ elif isinstance(target, str) and target.startswith("cutedsl"):
+ cuda_target_str = target.replace("cutedsl", "cuda", 1)
+ temp_target = Target(cuda_target_str)
+ target_dict = dict(temp_target.export())
+ target_dict["keys"] = list(target_dict["keys"]) + ["cutedsl"]
+ return_var = Target(target_dict)
else:
# Validate the target if it's not "auto"
if isinstance(target, Target):
@@ -115,6 +124,8 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj
else:
raise AssertionError(f"Target {target} is not supported")
+ if isinstance(return_var, Target):
+ return return_var
if return_object:
if isinstance(return_var, Target):
return return_var
...