Unverified Commit 7248a810 authored by Gabriel Wu's avatar Gabriel Wu Committed by GitHub
Browse files

feat(cutedsl): add CuTeDSL backend (#1421)



* feat: CuTeDSL backend

* fix: clang-tidy

* fix: clang-format

* fix: ci

* fix: revert example gemm fp8

* fix: remove duplicate code

* fix: switch-case

* fix: fp16 silence

* fix: TVM IR print

* fix: useless tir

* fix: clang-format

* fix: remove tilelang/contrib/cutedsl/.gitignore

* fix: use hexfloat

* fix: gsym guard

* fix: unknown storage sync type

* fix: string literal

* fix: add args guard

* fix: name hint dedup

* fix: better find_kernel_by_pattern

* fix: set libpath for from_database path

* fix: guard buffer.strides

* fix: from guard

* fix: eviction guard

* fix: use thread local tma descs

* fix: ruff

* fix: drop tma_init_cpp

* fix: exc_info

* fix: negative unmatch early return

* fix: rename postproc func and add test

* fix: handle fast math according to pass config

* fix: dyn_sym parse

* fix: wrap_forward

* fix: use tvm_ffi.libinfo instead of cli

* fix: keep signature

* fix: C++ string safety

* fix: mark tma_store_add as unsupported

* fix: tvm version

* resolve ldsm and cpasync issues.

* fix: minor fixes

* fix: parse signature using ast

* fix: guard global_addr

* fix: create tempfile only when necessary

* fix: use logger.exception for exceptions

* fix: guard lib_path and host_func

* fix: remove tma_cpp_init and add timeout for cpp compile

* add timeout for mbarrier_wait.

* fix: _load_kernel_from_disk signature

* resolve codegen issues.

* fix: logger.exception

* add comment for div_by=1

* merge

* fix: reserve cutlass,cute,tl

* fix: guard tma_store

* fix: allow int64 offset in make_tensor_at_offset

* fix: guard barrier

* fix: add comments for div_by=16

* fix: div_by=1 issue

* delete div_by when offset is 0

* use tl.make_tensor when offset is 0

* fix: explicitly check cutedsl target

* fix: use param.torch_dtype()

---------
Co-authored-by: default avataryuxic <yuxic@nvidia.com>
Co-authored-by: default avatarYong <yong@local>
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent a6f59f31
"""
Reduce operations for CuTeDSL backend.
Based on tl_templates/cuda/reduce.h
"""
from __future__ import annotations
import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import Int32, Float32
from cutlass.cutlass_dsl import dsl_user_op, T
from cutlass._mlir.dialects import nvvm
from cutlass.cute.arch.nvvm_wrappers import shuffle_sync_op
@dsl_user_op
def min(a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None) -> Float32:
    """Return the minimum of two (or optionally three) values as f32 via NVVM `fmin`.

    Mirrors the C++ `tl` reduce helpers (see tl_templates/cuda/reduce.h). All
    operands are coerced to Float32 before lowering; the third operand `c` is
    forwarded only when provided.

    NOTE: intentionally shadows the Python builtin `min` so DSL code reads the
    same as the CUDA templates; import this module's names deliberately.
    """
    return Float32(
        nvvm.fmin(
            T.f32(),
            Float32(a).ir_value(loc=loc, ip=ip),
            Float32(b).ir_value(loc=loc, ip=ip),
            # Three-operand fmin is emitted only when a third value is given.
            c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
            loc=loc,
            ip=ip,
        )
    )
@dsl_user_op
def max(a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None) -> Float32:
    """Return the maximum of two (or optionally three) values as f32 via NVVM `fmax`.

    Mirrors the C++ `tl` reduce helpers (see tl_templates/cuda/reduce.h). All
    operands are coerced to Float32 before lowering; the third operand `c` is
    forwarded only when provided.

    NOTE: intentionally shadows the Python builtin `max`, matching `min` above.
    """
    return Float32(
        nvvm.fmax(
            T.f32(),
            Float32(a).ir_value(loc=loc, ip=ip),
            Float32(b).ir_value(loc=loc, ip=ip),
            # Three-operand fmax is emitted only when a third value is given.
            c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
            loc=loc,
            ip=ip,
        )
    )
class SumOp:
    """Reduction combiner that adds two partial results with `+`."""

    @staticmethod
    def __call__(x, y):
        # Works for any type implementing __add__ (scalars, DSL values, ...).
        combined = x + y
        return combined
class MaxOp:
    """Reduction combiner that keeps the larger of two partial results."""

    @staticmethod
    def __call__(x, y):
        # Delegates to `max` in the enclosing scope (the module-level
        # NVVM-backed wrapper when used inside this module).
        return max(x, y)
class MinOp:
    """Reduction combiner that keeps the smaller of two partial results."""

    @staticmethod
    def __call__(x, y):
        # Delegates to `min` in the enclosing scope, which is JIT-friendly
        # inside this module (NVVM-backed wrapper).
        return min(x, y)
class BitAndOp:
    """Reduction combiner computing the bitwise AND of two partials."""

    @staticmethod
    def __call__(x, y):
        result = x & y
        return result
class BitOrOp:
    """Reduction combiner computing the bitwise OR of two partials."""

    @staticmethod
    def __call__(x, y):
        result = x | y
        return result
class BitXorOp:
    """Reduction combiner computing the bitwise XOR of two partials."""

    @staticmethod
    def __call__(x, y):
        result = x ^ y
        return result
def bar_sync(barrier_id, number_of_threads):
    """Synchronize `number_of_threads` threads on named barrier `barrier_id`.

    Thin wrapper over `cute.arch.barrier`. See `bar_sync_ptx` below for the
    inline-asm variant used where the compiler must not reorder around it.
    """
    cute.arch.barrier(barrier_id=barrier_id, number_of_threads=number_of_threads)
def bar_sync_ptx(barrier_id, number_of_threads):
    """Emit `bar.sync $barrier_id, $number_of_threads;` via LLVM inline asm.

    The inline-asm form is opaque to the compiler, so surrounding memory
    operations cannot be reordered across the barrier (this is relied on by
    `AllReduce.run_hopper`).
    """
    # Local import keeps the MLIR llvm dialect off the module import path.
    from cutlass._mlir.dialects import llvm
    llvm.inline_asm(
        None,
        [Int32(barrier_id).ir_value(), Int32(number_of_threads).ir_value()],
        "bar.sync $0, $1;",
        "r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
def AllReduce(reducer, threads, scale, thread_offset, all_threads=None):
    """
    AllReduce operation implementing warp/block-level reduction.
    Based on tl::AllReduce from reduce.h
    Args:
        reducer: Reducer operator class (SumOp, MaxOp, etc.)
        threads: Number of threads participating in reduction
        scale: Reduction scale factor
        thread_offset: Thread ID offset
        all_threads: Total number of threads in block
    Returns:
        A callable object with run() and run_hopper() methods
    """
    class AllReduceInstance:
        def __init__(self, reducer, threads, scale, thread_offset: cutlass.Constexpr[int], all_threads: cutlass.Constexpr[int]):
            self.reducer = reducer
            self.threads = threads
            self.scale = scale
            self.thread_offset = thread_offset
            # Default the block-wide thread count to the reduction width.
            self.all_threads = all_threads if all_threads is not None else threads
        def run(self, x, red_buf: cute.Pointer | None = None):
            """
            Perform all-reduce across threads.
            Based on tl::AllReduce<...>::run from reduce.h

            NOTE(review): `red_buf` is dereferenced unconditionally on the
            offset >= 32 path (i.e. threads >= 64), so callers must supply a
            valid buffer pointer in that case — confirm at call sites.
            """
            # Butterfly reduction: the active exchange distance halves each step.
            offset = self.threads // 2
            if offset >= 32:
                # Cross-warp step: exchange partials through red_buf, with a
                # barrier on each side of the write to order the accesses.
                cute.arch.sync_threads()
                tidx, _, _ = cute.arch.thread_idx()
                cute.make_tensor(red_buf + tidx - self.thread_offset, (1,))[0] = x
                cute.arch.sync_threads()
                # Combine with the partner thread at XOR distance `offset`.
                x = self.reducer()(x, cute.make_tensor(red_buf + ((tidx - self.thread_offset) ^ offset), (1,))[0])
            else:
                # Use warp shuffle for small thread counts
                # Use the pre-existing shuffle_sync_op with butterfly (XOR) mode
                other = shuffle_sync_op(x, offset, mask=0xFFFFFFFF, mask_and_clamp=0x1F, kind=nvvm.ShflKind.bfly)
                x = self.reducer()(x, other)
            # Recurse with the halved width until `offset` reaches `scale`.
            return (
                x
                if offset == self.scale
                else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run(x, red_buf)
            )
        def run_hopper(self, x, red_buf: cute.Pointer | None = None):
            """
            Perform all-reduce on Hopper architecture using bar.sync.
            Based on tl::AllReduce<...>::run_hopper from reduce.h

            Same structure as run(), but uses inline-asm named barriers so the
            compiler cannot reorder instructions across them.
            """
            offset = self.threads // 2
            tidx, _, _ = cute.arch.thread_idx()
            if offset >= 32:
                # Use inlined asm for bar.sync to avoid instruction reordering
                bar_sync_ptx(1, self.all_threads)
                cute.make_tensor(red_buf + tidx - self.thread_offset, (1,))[0] = x
                bar_sync_ptx(2, self.all_threads)
                x = self.reducer()(x, cute.make_tensor(red_buf + ((tidx - self.thread_offset) ^ offset), (1,))[0])
            else:
                # Use warp shuffle for small thread counts
                # Use the pre-existing shuffle_sync_op with butterfly (XOR) mode
                other = shuffle_sync_op(x, offset, mask=0xFFFFFFFF, mask_and_clamp=0x1F, kind=nvvm.ShflKind.bfly)
                x = self.reducer()(x, other)
            return (
                x
                if offset == self.scale
                else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run_hopper(x, red_buf)
            )
    return AllReduceInstance(reducer, threads, scale, thread_offset, all_threads)
import cutlass.cute as cute
from cutlass.cute.typing import Constexpr
from dataclasses import dataclass
@dataclass(frozen=True)
class dim3:
    """Immutable (x, y, z) index triple, mirroring CUDA's built-in `dim3`."""
    # Components are plain ints at trace time (unpacked from cute.arch tuples).
    x: int
    y: int
    z: int
def ThreadIdx() -> dim3:
    """Return the current thread index as a `dim3` (CUDA threadIdx analog)."""
    tx, ty, tz = cute.arch.thread_idx()
    return dim3(tx, ty, tz)
def BlockIdx() -> dim3:
    """Return the current block index as a `dim3` (CUDA blockIdx analog)."""
    bx, by, bz = cute.arch.block_idx()
    return dim3(bx, by, bz)
def GridDim() -> dim3:
    """Return the launch grid dimensions as a `dim3` (CUDA gridDim analog)."""
    gx, gy, gz = cute.arch.grid_dim()
    return dim3(gx, gy, gz)
@cute.jit
def rasterization2DRow(panel_width: Constexpr[int]) -> dim3:
    """Remap the 2D block index into row-panel order.

    Blocks are grouped into horizontal panels of `panel_width` rows; odd panels
    are traversed in reverse column order (serpentine), which is presumably for
    better L2 reuse — behavior mirrors the C++ rasterization helpers.
    """
    blockIdx = BlockIdx()
    gridDim = GridDim()
    # Linearize the 2D launch grid.
    block_idx = blockIdx.x + blockIdx.y * gridDim.x
    grid_size = gridDim.x * gridDim.y
    panel_size = panel_width * gridDim.x
    panel_offset = block_idx % panel_size
    panel_idx = block_idx // panel_size
    total_panel = cute.ceil_div(grid_size, panel_size)
    # The last panel may be narrower than panel_width when the grid does not
    # divide evenly.
    stride = panel_width if panel_idx + 1 < total_panel else (grid_size - panel_idx * panel_size) // gridDim.x
    # Reverse traversal direction on odd panels (serpentine order).
    col_idx = (gridDim.x - 1 - panel_offset // stride) if (panel_idx & 1 != 0) else (panel_offset // stride)
    row_idx = panel_offset % stride + panel_idx * panel_width
    return dim3(col_idx, row_idx, blockIdx.z)
@cute.jit
def rasterization2DColumn(panel_width: Constexpr[int]) -> dim3:
    """Remap the 2D block index into column-panel order.

    Column-major counterpart of `rasterization2DRow`: panels span
    `panel_width` columns, and odd panels are traversed in reverse row order
    (serpentine).
    """
    blockIdx = BlockIdx()
    gridDim = GridDim()
    # Linearize the 2D launch grid.
    block_idx = blockIdx.x + blockIdx.y * gridDim.x
    grid_size = gridDim.x * gridDim.y
    panel_size = panel_width * gridDim.y
    panel_offset = block_idx % panel_size
    panel_idx = block_idx // panel_size
    total_panel = cute.ceil_div(grid_size, panel_size)
    # The last panel may be narrower than panel_width when the grid does not
    # divide evenly.
    stride = panel_width if panel_idx + 1 < total_panel else (grid_size - panel_idx * panel_size) // gridDim.y
    # Reverse traversal direction on odd panels (serpentine order).
    row_idx = (gridDim.y - 1 - panel_offset // stride) if (panel_idx & 1 != 0) else (panel_offset // stride)
    col_idx = panel_offset % stride + panel_idx * panel_width
    return dim3(col_idx, row_idx, blockIdx.z)
......@@ -197,7 +197,8 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
device_mod = tir.transform.Simplify()(device_mod)
if target.kind.name == "cuda":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda")
device_mod = tvm.ffi.get_global_func(global_func)(device_mod, target)
elif target.kind.name == "hip":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
else:
......@@ -211,7 +212,8 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
device_mod = tilelang.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod)
if target.kind.name == "cuda":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(device_mod, target)
global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda") + "_without_compile"
device_mod = tvm.ffi.get_global_func(global_func)(device_mod, target)
elif target.kind.name == "hip":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(device_mod, target)
elif target.kind.name == "c":
......
......@@ -49,7 +49,7 @@ _Ret = TypeVar("_Ret")
def compile(
func: PrimFunc[_KP, _T] = None,
out_idx: list[int] | int | None = None,
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
target: str | Target = "auto",
target_host: str | Target | None = None,
verbose: bool = False,
......@@ -64,7 +64,7 @@ def compile(
The TileLang TIR function to compile and wrap.
out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
Execution backend to use for kernel execution. Use "auto" to pick a sensible
default per target (cuda->tvm_ffi, metal->torch, others->cython).
target : Union[str, Target], optional
......@@ -118,7 +118,7 @@ def compile(
def par_compile(
funcs: Iterable[PrimFunc[_KP, _T]],
out_idx: list[int] | int | None = None,
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
target: str | Target = "auto",
target_host: str | Target | None = None,
verbose: bool = False,
......@@ -135,7 +135,7 @@ def par_compile(
The TileLang TIR functions to compile and wrap.
out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
Execution backend to use for kernel execution. Use "auto" to pick a sensible
default per target (cuda->tvm_ffi, metal->torch, others->cython).
target : Union[str, Target], optional
......@@ -256,7 +256,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
"""
out_idx: list[int] | int | None
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"]
target: str | Target
target_host: str | Target
verbose: bool
......@@ -424,7 +424,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
return kernel
ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"]
@overload
......@@ -473,7 +473,7 @@ def jit( # This is the new public interface
Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
target_host : Union[str, Target], optional
Target host for cross-compilation. Defaults to None.
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
Backend for kernel execution and argument passing. Use "auto" to pick a sensible
default per target (cuda->tvm_ffi, metal->torch, others->cython).
verbose : bool, optional
......
......@@ -4,3 +4,4 @@ from .ctypes import CtypesKernelAdapter # noqa: F401
from .cython import CythonKernelAdapter # noqa: F401
from .nvrtc import NVRTCKernelAdapter # noqa: F401
from .torch import MetalKernelAdapter # noqa: F401
from .cutedsl import CuTeDSLKernelAdapter # noqa: F401
"""CuTeDSL Backend for TileLang.
This module provides runtime compilation support using NVIDIA's CuTeDSL API.
"""
__all__ = [
"CuTeDSLKernelAdapter",
"TLCuTeDSLSourceWrapper",
"CuTeDSLLibraryGenerator",
"check_cutedsl_available",
]
from .checks import check_cutedsl_available # noqa: F401
from .adapter import CuTeDSLKernelAdapter # noqa: F401
from .wrapper import TLCuTeDSLSourceWrapper # noqa: F401
from .libgen import CuTeDSLLibraryGenerator # noqa: F401
This diff is collapsed.
from __future__ import annotations
import re
from importlib import metadata as _importlib_metadata
from importlib.util import find_spec as _find_spec
import os
_CUTEDSL_PUBLIC_DIST = "nvidia-cutlass-dsl"
_CUTEDSL_MIN_VERSION = (4, 3, 1)
_VERSION_TRIPLE_RE = re.compile(r"(\d+)\.(\d+)\.(\d+)")
def _parse_version_triple(version_str: str) -> tuple[int, int, int] | None:
"""Parse a best-effort (major, minor, patch) triple from a version string.
We intentionally avoid importing heavy/optional version parsers. For our
minimum requirement (>= 4.3.1), a numeric triple comparison is sufficient.
"""
m = _VERSION_TRIPLE_RE.search(version_str)
if not m:
return None
return int(m.group(1)), int(m.group(2)), int(m.group(3))
def _min_version_str() -> str:
    """Render the minimum supported CuTeDSL version as a dotted string."""
    return ".".join(str(part) for part in _CUTEDSL_MIN_VERSION)
def _requirement_spec() -> str:
    """Build the pip requirement string for the public CuTeDSL distribution."""
    return _CUTEDSL_PUBLIC_DIST + ">=" + _min_version_str()
def check_cutedsl_available() -> None:
    """Fail fast if the CuTeDSL backend cannot be used in this Python environment.
    Policy:
    - If the public distribution `nvidia-cutlass-dsl` is installed, require version >= a minimum supported version.
    - Regardless of distribution metadata, require that `cutlass.cute` is importable.
    This intentionally does not mention or special-case any internal distributions.

    Raises:
        ImportError: if the installed distribution is too old, or no importable
            `cutlass` package with a `cute` submodule can be found.
    """
    # 1) Version gate (only when the public dist metadata is present)
    try:
        dist_version = _importlib_metadata.version(_CUTEDSL_PUBLIC_DIST)
    except _importlib_metadata.PackageNotFoundError:
        dist_version = None
    except Exception:
        # Metadata is best-effort; don't block internal/nonstandard installs here.
        dist_version = None
    if dist_version is not None:
        parsed = _parse_version_triple(dist_version)
        # An unparseable version is treated as too old (fail closed).
        if parsed is None or parsed < _CUTEDSL_MIN_VERSION:
            req = _requirement_spec()
            raise ImportError(
                f"CuTeDSL backend requires `{req}`, but found version `{dist_version}`. Please run: `pip install -U '{req}'`."
            )
    # 2) Capability probe: keep it cheap.
    # Importing cutlass/cute can be expensive and defeats our lazy-import design,
    # especially on cache hits. We only require that the module is importable.
    cutlass_spec = _find_spec("cutlass")
    if cutlass_spec is None:
        req = _requirement_spec()
        raise ImportError(f"CuTeDSL backend requires the CUTLASS Python DSL with CuTe support (install via `pip install -U '{req}'`).")
    # Avoid find_spec("cutlass.cute") which can be surprisingly expensive.
    # Instead, check for a 'cute' submodule/package under cutlass's search locations.
    locs = getattr(cutlass_spec, "submodule_search_locations", None)
    has_cute = False
    if locs:
        for base in locs:
            # Accept either a package directory or a single-file module.
            if os.path.isdir(os.path.join(base, "cute")) or os.path.isfile(os.path.join(base, "cute.py")):
                has_cute = True
                break
    if not has_cute:
        req = _requirement_spec()
        raise ImportError(f"CuTeDSL backend requires the CUTLASS Python DSL with CuTe support (install via `pip install -U '{req}'`).")
"""CuTeDSL Library Generator for TileLang.
This module provides library generation functionality for the CuTeDSL backend.
"""
from __future__ import annotations
import importlib.util
import os
import tempfile
import subprocess
from tvm.target import Target
from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.jit.adapter.utils import is_cutedsl_target
class CuTeDSLLibraryGenerator(LibraryGenerator):
    """Library generator for the CuTeDSL backend.

    Unlike the C++/NVRTC generators, the "library" here is a generated Python
    module (kernel.py) containing the host function; compile_lib() writes it to
    a temp directory and optionally compiles a C++ launcher shared library
    with nvcc.
    """
    # Generated Python host source; must be set via update_host_func() before compile_lib().
    host_func: str | None = None
    # NOTE(review): tma_cpp_init_code/tma_lib_name are stored but not used in
    # this class — presumably consumed elsewhere or vestigial; confirm.
    tma_cpp_init_code: str | None = None
    tma_lib_name: str | None = None
    # Optional C++ launcher source and the filename for its compiled .so.
    launcher_cpp_code: str | None = None
    launcher_lib_name: str | None = None
    # Module object produced by load_lib().
    pymodule = None
    def __init__(self, target: Target, verbose: bool = False):
        super().__init__(target, verbose)
    @staticmethod
    def import_from_file(module_name, file_path):
        # Standard importlib recipe: load a Python module directly from a path.
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
    def update_host_func(self, host_func: str):
        """Set the generated Python host source to be written by compile_lib()."""
        self.host_func = host_func
    def update_tma_cpp_init_code(self, tma_cpp_init_code: str):
        """Store TMA C++ init source (not consumed by this class directly)."""
        self.tma_cpp_init_code = tma_cpp_init_code
    def update_tma_lib_name(self, tma_lib_name: str):
        """Store the TMA library filename (not consumed by this class directly)."""
        self.tma_lib_name = tma_lib_name
    def update_launcher_cpp_code(self, launcher_cpp_code: str):
        """Set the optional C++ launcher source compiled by compile_lib()."""
        self.launcher_cpp_code = launcher_cpp_code
    def update_launcher_lib_name(self, launcher_lib_name: str):
        """Set the output filename for the compiled launcher shared library."""
        self.launcher_lib_name = launcher_lib_name
    def load_lib(self, lib_path: str | None = None):
        """Import the generated kernel module (from lib_path or self.libpath)."""
        if lib_path is None:
            if self.libpath is None:
                raise RuntimeError("CuTeDSLLibraryGenerator.libpath is not set; call compile_lib() first or pass lib_path explicitly.")
            lib_path = self.libpath
        self.pymodule = self.import_from_file("kernel", lib_path)
    def compile_lib(self, timeout: float | None = None):
        """Materialize the generated module and (optionally) build the launcher.

        Args:
            timeout: Seconds to allow the nvcc subprocess; None means no limit.
                subprocess.TimeoutExpired propagates to the caller on expiry.

        Raises:
            RuntimeError: if host_func was never set, or nvcc fails.
            ValueError: if the target is not a CuTeDSL target.
        """
        if self.host_func is None:
            raise RuntimeError("CuTeDSLLibraryGenerator.host_func is not set; call update_host_func() before compile_lib().")
        target = self.target
        if is_cutedsl_target(target):
            # Use a dedicated temp directory per kernel so CuTeDSL artifacts (e.g. kept .cubin)
            # never pollute user CWD, and are easy to locate alongside the generated module.
            work_dir = tempfile.mkdtemp(prefix="tilelang_cutedsl_")
            src_path = os.path.join(work_dir, "kernel.py")
            with open(src_path, "w") as f:
                # Note: lib_code (containing @cute.kernel definitions) is embedded
                # inside host_func's _generate_cubin_if_needed function, so we only
                # write host_func here. This ensures cute imports are lazy-loaded.
                f.write(self.host_func)
            # Compile C++ launcher library if needed
            if self.launcher_cpp_code is not None:
                # NOTE(review): delete=False with no later os.unlink leaks the
                # temp .cpp file; consider cleaning it up after compilation.
                with tempfile.NamedTemporaryFile(
                        mode="w",
                        suffix=".cpp",
                        delete=False,
                ) as launcher_src:
                    launcher_src.write(self.launcher_cpp_code)
                    launcher_src_path = launcher_src.name
                # Generate launcher lib under the same directory as the source file
                launcher_lib_path = os.path.join(os.path.dirname(src_path), self.launcher_lib_name)
                # Get TVM FFI compiler flags using tvm_ffi.libinfo API
                try:
                    import tvm_ffi.libinfo
                    include_paths = tvm_ffi.libinfo.include_paths()
                    tvm_cxxflags = [f"-I{path}" for path in include_paths]
                    lib_path = tvm_ffi.libinfo.find_libtvm_ffi()
                    lib_dir = os.path.dirname(lib_path)
                    tvm_ldflags = [f"-L{lib_dir}", "-ltvm_ffi"]
                except (ImportError, RuntimeError):
                    # tvm_ffi unavailable or libinfo functions failed
                    tvm_cxxflags = []
                    tvm_ldflags = []
                # Compile with nvcc (need CUDA driver API)
                compile_cmd = [
                    "nvcc",
                    "-shared",
                    "-Xcompiler=-fPIC",
                    "-lcuda",
                    *tvm_cxxflags,
                    *tvm_ldflags,
                    "-o",
                    launcher_lib_path,
                    launcher_src_path,
                ]
                result = subprocess.run(compile_cmd, check=False, capture_output=True, text=True, timeout=timeout)
                if result.returncode != 0:
                    raise RuntimeError(f"Failed to compile C++ launcher: {result.stderr}")
                self.launcher_libpath = launcher_lib_path
                self.launcher_libname = self.launcher_lib_name
            self.srcpath = src_path
            # For CuTeDSL the "library" is the generated Python module itself.
            self.libpath = src_path
        else:
            raise ValueError(f"Unsupported target: {target}")
This diff is collapsed.
......@@ -76,7 +76,9 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
self.wrapper.assign_pass_configs(pass_configs)
self.wrapper.assign_host_module(host_mod)
self.wrapper.assign_device_module(device_mod)
self.host_func, self.function_names = self.wrapper.wrap(device_kernel_source)
wrapper_result = self.wrapper.wrap(device_kernel_source)
self.host_func = wrapper_result["host_func"]
self.function_names = wrapper_result["function_names"]
self.lib_generator = NVRTCLibraryGenerator(self.target, self.verbose)
self.lib_generator.update_lib_code(self.device_kernel_source)
......
......@@ -273,7 +273,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
Casts are noise in generated Python code - Python is dynamically typed.
"""
return pythonic_expr(expr, self._TYPE_MAP, ignore_cast=True)
return pythonic_expr(expr, self._TYPE_MAP, ignore_cast=True, floor_div_op="//")
def create_dispatch_func(self, code, function_informations):
"""Generate Python dispatch function that launches multiple CUDA kernels.
......
......@@ -38,6 +38,53 @@ def match_declare_kernel(source: str, annotation: str = "__global__") -> int:
raise ValueError("No global kernel found in the source code")
def match_declare_kernel_cutedsl(source: str, annotation: str = "@cute.kernel") -> int:
    """Locate the opening parenthesis of the first decorated kernel in *source*.

    Searches for *annotation* (a decorator such as "@cute.kernel") followed by
    a `def` statement and returns the index of the '(' that opens the kernel's
    parameter list.

    Args:
        source: Python source code to scan.
        annotation: Decorator marking the kernel (default "@cute.kernel").
            Fix: previously this parameter was accepted but silently ignored;
            it is now used to build the search pattern.

    Returns:
        Index into *source* of the '(' after the kernel's name.

    Raises:
        ValueError: If no decorated kernel definition is found.
    """
    # \s+ allows any whitespace, including newlines, between the decorator
    # and the def; re.escape keeps literal dots in the annotation literal.
    pattern = re.escape(annotation) + r"\s+def\s+(\w+)"
    matched = re.search(pattern, source, re.MULTILINE)
    if matched:
        # matched.start(1) is the position of the function name; the parameter
        # list's '(' is the first one after it.
        func_name_pos = matched.start(1)
        paren_pos = source.find("(", func_name_pos)
        if paren_pos != -1:
            return paren_pos
    raise ValueError("No global kernel found in the source code")
def extract_python_func_declaration(source: str, func_name: str) -> str:
    """Extract a Python function declaration from `def` through the closing ')'.

    Args:
        source: Source code containing the function.
        func_name: Name of the function to extract (may include a trailing '(').

    Returns:
        The declaration from 'def' up to and including the ')' that closes the
        parameter list (the trailing ':' is not included). Multi-line
        signatures are supported.

    Raises:
        ValueError: If no matching function declaration is found.

    Example:
        For code:
            @cute.kernel
            def kernel(arg1: cute.Tensor, arg2: int):
                ...
        Returns: "def kernel(arg1: cute.Tensor, arg2: int)"
    """
    # Normalize: callers sometimes pass the name with a trailing '('.
    if func_name.endswith("("):
        func_name = func_name[:-1]
    matched = re.search(rf"def\s+{re.escape(func_name)}\s*\(", source)
    if matched:
        # Fix: the previous regex `\([^)]*\)` broke on nested parentheses in
        # defaults/annotations (e.g. `x=(1, 2)`). Scan with a parenthesis
        # counter instead so nesting and multi-line signatures both work.
        depth = 0
        for idx in range(matched.end() - 1, len(source)):
            ch = source[idx]
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth == 0:
                    return source[matched.start():idx + 1]
    raise ValueError(f"No function declaration found for {func_name}")
def match_declare_kernel_cpu(source: str, annotation: str = "int32_t") -> int:
pattern = r"int32_t\s+\w+"
for line in source.split("\n"):
......@@ -64,6 +111,10 @@ def is_metal_target(target: Target) -> bool:
return target.kind.name == "metal"
def is_cutedsl_target(target: Target) -> bool:
    """Whether *target* is the CuTeDSL flavor of the CUDA target.

    CuTeDSL targets are CUDA targets that additionally carry the "cutedsl" key.
    """
    if target.kind.name != "cuda":
        return False
    return "cutedsl" in target.keys
def get_annotated_mod(
func_or_mod: tir.PrimFunc | tvm.IRModule,
target: str | Target = "auto",
......@@ -102,7 +153,9 @@ def get_annotated_mod(
return dispatch[model_type](mod)
def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None, ignore_cast: bool = False) -> str:
def pythonic_expr(
expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None, ignore_cast: bool = False, floor_div_op: str = "/"
) -> str:
"""
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
......@@ -110,6 +163,10 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = Non
expr: The TVM PrimExpr to convert.
dtype_map: A dictionary mapping data types to their string representations.
ignore_cast: Whether to ignore the cast operator and return the string representation of the value without the cast.
floor_div_op: Operator to use for tvm.tir.FloorDiv. Default '/' preserves prior
behavior (suitable for generating C/C++ expressions). For generating
Python code where integer division is required (e.g. grid/block),
pass '//' explicitly.
Returns:
A string representation of the expression.
"""
......@@ -180,7 +237,7 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = Non
):
op_map = {
tvm.tir.Mul: "*",
tvm.tir.FloorDiv: "/",
tvm.tir.FloorDiv: floor_div_op,
tvm.tir.Add: "+",
tvm.tir.Sub: "-",
tvm.tir.FloorMod: "%",
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment