Unverified Commit 7248a810 authored by Gabriel Wu's avatar Gabriel Wu Committed by GitHub
Browse files

feat(cutedsl): add CuTeDSL backend (#1421)



* feat: CuTeDSL backend

* fix: clang-tidy

* fix: clang-format

* fix: ci

* fix: revert example gemm fp8

* fix: remove duplicate code

* fix: switch-case

* fix: fp16 silence

* fix: TVM IR print

* fix: useless tir

* fix: clang-format

* fix: remove tilelang/contrib/cutedsl/.gitignore

* fix: use hexfloat

* fix: gsym guard

* fix: unknown storage sync type

* fix: string literal

* fix: add args guard

* fix: name hint dedup

* fix: better find_kernel_by_pattern

* fix: set libpath for from_database path

* fix: guard buffer.strides

* fix: from guard

* fix: eviction guard

* fix: use thread local tma descs

* fix: ruff

* fix: drop tma_init_cpp

* fix: exc_info

* fix: negative unmatch early return

* fix: rename postproc func and add test

* fix: handle fast math according to pass config

* fix: dyn_sym parse

* fix: wrap_forward

* fix: use tvm_ffi.libinfo instead of cli

* fix: keep signature

* fix: C++ string safety

* fix: mark tma_store_add as unsupported

* fix: tvm version

* resolve ldsm and cpasync issues.

* fix: minor fixes

* fix: parse signature using ast

* fix: guard global_addr

* fix: create tempfile only when necessary

* fix: use logger.exception for exceptions

* fix: guard lib_path and host_func

* fix: remove tma_cpp_init and add timeout for cpp compile

* add timeout for mbarrier_wait.

* fix: _load_kernel_from_disk signature

* resolve codegen issues.

* fix: logger.exception

* add comment for div_by=1

* merge

* fix: reserve cutlass,cute,tl

* fix: guard tma_store

* fix: allow int64 offset in make_tensor_at_offset

* fix: guard barrier

* fix: add comments for div_by=16

* fix: div_by=1 issue

* delete div_by when offset is 0

* use tl.make_tensor when offset is 0

* fix: explicitly check cutedsl target

* fix: use param.torch_dtype()

---------
Co-authored-by: yuxic <yuxic@nvidia.com>
Co-authored-by: Yong <yong@local>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent a6f59f31
"""
Reduce operations for CuTeDSL backend.
Based on tl_templates/cuda/reduce.h
"""
from __future__ import annotations
import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import Int32, Float32
from cutlass.cutlass_dsl import dsl_user_op, T
from cutlass._mlir.dialects import nvvm
from cutlass.cute.arch.nvvm_wrappers import shuffle_sync_op
@dsl_user_op
def min(a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None) -> Float32:
    """f32 minimum lowered to the NVVM `fmin` intrinsic.

    Intentionally shadows the builtin `min` so the reducer classes below
    pick up this JIT-friendly version. The optional third operand is
    forwarded as `fmin`'s `c` keyword when provided.
    """
    lhs = Float32(a).ir_value(loc=loc, ip=ip)
    rhs = Float32(b).ir_value(loc=loc, ip=ip)
    third = Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None
    return Float32(nvvm.fmin(T.f32(), lhs, rhs, c=third, loc=loc, ip=ip))
@dsl_user_op
def max(a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None) -> Float32:
    """f32 maximum lowered to the NVVM `fmax` intrinsic.

    Intentionally shadows the builtin `max` so the reducer classes below
    pick up this JIT-friendly version. The optional third operand is
    forwarded as `fmax`'s `c` keyword when provided.
    """
    lhs = Float32(a).ir_value(loc=loc, ip=ip)
    rhs = Float32(b).ir_value(loc=loc, ip=ip)
    third = Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None
    return Float32(nvvm.fmax(T.f32(), lhs, rhs, c=third, loc=loc, ip=ip))
class SumOp:
    """Reduction combiner that adds two partial results."""

    @staticmethod
    def __call__(lhs, rhs):
        # Plain `+` on the partials; works for both scalars and DSL values.
        return lhs + rhs
class MaxOp:
    """Reduction combiner that keeps the larger of two partial results."""

    @staticmethod
    def __call__(lhs, rhs):
        # Resolves to the module-level `max` wrapper when defined in this file.
        return max(lhs, rhs)
class MinOp:
    """Reduction combiner that keeps the smaller of two partial results."""

    @staticmethod
    def __call__(lhs, rhs):
        # Resolves to the module-level JIT-friendly `min` wrapper when
        # defined in this file.
        return min(lhs, rhs)
class BitAndOp:
    """Reduction combiner computing the bitwise AND of two partials."""

    @staticmethod
    def __call__(lhs, rhs):
        return lhs & rhs
class BitOrOp:
    """Reduction combiner computing the bitwise OR of two partials."""

    @staticmethod
    def __call__(lhs, rhs):
        return lhs | rhs
class BitXorOp:
    """Reduction combiner computing the bitwise XOR of two partials."""

    @staticmethod
    def __call__(lhs, rhs):
        return lhs ^ rhs
def bar_sync(barrier_id, number_of_threads):
    """Named-barrier sync for `number_of_threads` threads via the CuTe arch API."""
    cute.arch.barrier(barrier_id=barrier_id, number_of_threads=number_of_threads)
def bar_sync_ptx(barrier_id, number_of_threads):
    """Named-barrier sync emitted as raw PTX `bar.sync` inline asm.

    Used instead of `cute.arch.barrier` where instruction reordering around
    the barrier must be avoided (see AllReduce.run_hopper below).
    """
    from cutlass._mlir.dialects import llvm
    llvm.inline_asm(
        None,  # no result value
        [Int32(barrier_id).ir_value(), Int32(number_of_threads).ir_value()],
        "bar.sync $0, $1;",
        "r,r",  # both operands passed in registers
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
def AllReduce(reducer, threads, scale, thread_offset, all_threads=None):
    """
    AllReduce operation implementing warp/block-level reduction.
    Based on tl::AllReduce from reduce.h
    Args:
        reducer: Reducer operator class (SumOp, MaxOp, etc.)
        threads: Number of threads participating in this reduction round
        scale: Reduction scale factor; recursion stops once the active
            span (threads // 2) has shrunk to this value
        thread_offset: Thread ID offset of the first participating thread
        all_threads: Total number of threads in the block (defaults to
            `threads` when None); used as the named-barrier arrival count
            in run_hopper()
    Returns:
        A callable object with run() and run_hopper() methods
    """

    class AllReduceInstance:

        def __init__(self, reducer, threads, scale, thread_offset: cutlass.Constexpr[int], all_threads: cutlass.Constexpr[int]):
            self.reducer = reducer
            self.threads = threads
            self.scale = scale
            self.thread_offset = thread_offset
            # Fall back to `threads` so barriers cover exactly the participants.
            self.all_threads = all_threads if all_threads is not None else threads

        def run(self, x, red_buf: cute.Pointer | None = None):
            """
            Perform all-reduce across threads.
            Based on tl::AllReduce<...>::run from reduce.h

            `red_buf` is only dereferenced when threads >= 64 (the buffer
            path below); presumably it points to shared memory sized for
            the participating threads — confirm against callers.
            """
            # Each round folds the element at XOR-distance `offset` into x.
            offset = self.threads // 2
            if offset >= 32:
                # Use shared memory for large thread counts
                cute.arch.sync_threads()
                tidx, _, _ = cute.arch.thread_idx()
                cute.make_tensor(red_buf + tidx - self.thread_offset, (1,))[0] = x
                cute.arch.sync_threads()
                x = self.reducer()(x, cute.make_tensor(red_buf + ((tidx - self.thread_offset) ^ offset), (1,))[0])
            else:
                # Use warp shuffle for small thread counts
                # Use the pre-existing shuffle_sync_op with butterfly (XOR) mode
                other = shuffle_sync_op(x, offset, mask=0xFFFFFFFF, mask_and_clamp=0x1F, kind=nvvm.ShflKind.bfly)
                x = self.reducer()(x, other)
            # Recurse with a halved span until the span equals `scale`.
            return (
                x
                if offset == self.scale
                else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run(x, red_buf)
            )

        def run_hopper(self, x, red_buf: cute.Pointer | None = None):
            """
            Perform all-reduce on Hopper architecture using bar.sync.
            Based on tl::AllReduce<...>::run_hopper from reduce.h

            Same algorithm as run(), but synchronizes with named barriers
            (ids 1 and 2) emitted as inline PTX instead of sync_threads().
            """
            offset = self.threads // 2
            tidx, _, _ = cute.arch.thread_idx()
            if offset >= 32:
                # Use inlined asm for bar.sync to avoid instruction reordering
                bar_sync_ptx(1, self.all_threads)
                cute.make_tensor(red_buf + tidx - self.thread_offset, (1,))[0] = x
                bar_sync_ptx(2, self.all_threads)
                x = self.reducer()(x, cute.make_tensor(red_buf + ((tidx - self.thread_offset) ^ offset), (1,))[0])
            else:
                # Use warp shuffle for small thread counts
                # Use the pre-existing shuffle_sync_op with butterfly (XOR) mode
                other = shuffle_sync_op(x, offset, mask=0xFFFFFFFF, mask_and_clamp=0x1F, kind=nvvm.ShflKind.bfly)
                x = self.reducer()(x, other)
            return (
                x
                if offset == self.scale
                else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run_hopper(x, red_buf)
            )

    return AllReduceInstance(reducer, threads, scale, thread_offset, all_threads)
import cutlass.cute as cute
from cutlass.cute.typing import Constexpr
from dataclasses import dataclass
@dataclass(frozen=True)
class dim3:
    """Immutable analogue of CUDA's `dim3` index/extent triple."""
    # Components match CUDA's dim3 member names.
    x: int
    y: int
    z: int
def ThreadIdx() -> dim3:
    """Return the current CUDA threadIdx as an immutable dim3."""
    return dim3(*cute.arch.thread_idx())
def BlockIdx() -> dim3:
    """Return the current CUDA blockIdx as an immutable dim3."""
    return dim3(*cute.arch.block_idx())
def GridDim() -> dim3:
    """Return the CUDA gridDim as an immutable dim3."""
    return dim3(*cute.arch.grid_dim())
@cute.jit
def rasterization2DRow(panel_width: Constexpr[int]) -> dim3:
    """Swizzle the 2D block index into a row-panel serpentine order.

    Blocks are regrouped into panels of `panel_width` rows; odd-numbered
    panels are traversed in reverse x order (serpentine). Presumably this
    mirrors tl::rasterization2DRow from the CUDA templates and exists to
    improve L2 locality — confirm against the C++ implementation.
    """
    blockIdx = BlockIdx()
    gridDim = GridDim()
    # Linearize the (x, y) block index.
    block_idx = blockIdx.x + blockIdx.y * gridDim.x
    grid_size = gridDim.x * gridDim.y
    # Number of blocks contained in one full panel.
    panel_size = panel_width * gridDim.x
    panel_offset = block_idx % panel_size
    panel_idx = block_idx // panel_size
    total_panel = cute.ceil_div(grid_size, panel_size)
    # The last panel may be short; shrink its stride to the leftover rows.
    # NOTE(review): stride becomes 0 when the leftover blocks do not fill a
    # full gridDim.x row — presumably callers guarantee divisibility; verify.
    stride = panel_width if panel_idx + 1 < total_panel else (grid_size - panel_idx * panel_size) // gridDim.x
    # Reverse the x direction on odd panels (serpentine traversal).
    col_idx = (gridDim.x - 1 - panel_offset // stride) if (panel_idx & 1 != 0) else (panel_offset // stride)
    row_idx = panel_offset % stride + panel_idx * panel_width
    return dim3(col_idx, row_idx, blockIdx.z)
@cute.jit
def rasterization2DColumn(panel_width: Constexpr[int]) -> dim3:
    """Swizzle the 2D block index into a column-panel serpentine order.

    Column-major counterpart of rasterization2DRow: panels span
    `panel_width` columns and odd-numbered panels are traversed in
    reverse y order. Presumably mirrors tl::rasterization2DColumn from
    the CUDA templates — confirm against the C++ implementation.
    """
    blockIdx = BlockIdx()
    gridDim = GridDim()
    # Linearize the (x, y) block index.
    block_idx = blockIdx.x + blockIdx.y * gridDim.x
    grid_size = gridDim.x * gridDim.y
    # Number of blocks contained in one full panel.
    panel_size = panel_width * gridDim.y
    panel_offset = block_idx % panel_size
    panel_idx = block_idx // panel_size
    total_panel = cute.ceil_div(grid_size, panel_size)
    # The last panel may be short; shrink its stride to the leftover columns.
    # NOTE(review): stride becomes 0 when the leftover blocks do not fill a
    # full gridDim.y column — presumably callers guarantee divisibility; verify.
    stride = panel_width if panel_idx + 1 < total_panel else (grid_size - panel_idx * panel_size) // gridDim.y
    # Reverse the y direction on odd panels (serpentine traversal).
    row_idx = (gridDim.y - 1 - panel_offset // stride) if (panel_idx & 1 != 0) else (panel_offset // stride)
    col_idx = panel_offset % stride + panel_idx * panel_width
    return dim3(col_idx, row_idx, blockIdx.z)
...@@ -197,7 +197,8 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule: ...@@ -197,7 +197,8 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
device_mod = tir.transform.Simplify()(device_mod) device_mod = tir.transform.Simplify()(device_mod)
if target.kind.name == "cuda": if target.kind.name == "cuda":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target) global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda")
device_mod = tvm.ffi.get_global_func(global_func)(device_mod, target)
elif target.kind.name == "hip": elif target.kind.name == "hip":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip")(device_mod, target) device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
else: else:
...@@ -211,7 +212,8 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> ...@@ -211,7 +212,8 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
device_mod = tilelang.transform.LowerIntrin()(device_mod) device_mod = tilelang.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod) device_mod = tir.transform.Simplify()(device_mod)
if target.kind.name == "cuda": if target.kind.name == "cuda":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(device_mod, target) global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda") + "_without_compile"
device_mod = tvm.ffi.get_global_func(global_func)(device_mod, target)
elif target.kind.name == "hip": elif target.kind.name == "hip":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(device_mod, target) device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(device_mod, target)
elif target.kind.name == "c": elif target.kind.name == "c":
......
...@@ -49,7 +49,7 @@ _Ret = TypeVar("_Ret") ...@@ -49,7 +49,7 @@ _Ret = TypeVar("_Ret")
def compile( def compile(
func: PrimFunc[_KP, _T] = None, func: PrimFunc[_KP, _T] = None,
out_idx: list[int] | int | None = None, out_idx: list[int] | int | None = None,
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
target: str | Target = "auto", target: str | Target = "auto",
target_host: str | Target | None = None, target_host: str | Target | None = None,
verbose: bool = False, verbose: bool = False,
...@@ -64,7 +64,7 @@ def compile( ...@@ -64,7 +64,7 @@ def compile(
The TileLang TIR function to compile and wrap. The TileLang TIR function to compile and wrap.
out_idx : Union[List[int], int], optional out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None). Index(es) of the output tensors to return (default: None).
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
Execution backend to use for kernel execution. Use "auto" to pick a sensible Execution backend to use for kernel execution. Use "auto" to pick a sensible
default per target (cuda->tvm_ffi, metal->torch, others->cython). default per target (cuda->tvm_ffi, metal->torch, others->cython).
target : Union[str, Target], optional target : Union[str, Target], optional
...@@ -118,7 +118,7 @@ def compile( ...@@ -118,7 +118,7 @@ def compile(
def par_compile( def par_compile(
funcs: Iterable[PrimFunc[_KP, _T]], funcs: Iterable[PrimFunc[_KP, _T]],
out_idx: list[int] | int | None = None, out_idx: list[int] | int | None = None,
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
target: str | Target = "auto", target: str | Target = "auto",
target_host: str | Target | None = None, target_host: str | Target | None = None,
verbose: bool = False, verbose: bool = False,
...@@ -135,7 +135,7 @@ def par_compile( ...@@ -135,7 +135,7 @@ def par_compile(
The TileLang TIR functions to compile and wrap. The TileLang TIR functions to compile and wrap.
out_idx : Union[List[int], int], optional out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None). Index(es) of the output tensors to return (default: None).
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
Execution backend to use for kernel execution. Use "auto" to pick a sensible Execution backend to use for kernel execution. Use "auto" to pick a sensible
default per target (cuda->tvm_ffi, metal->torch, others->cython). default per target (cuda->tvm_ffi, metal->torch, others->cython).
target : Union[str, Target], optional target : Union[str, Target], optional
...@@ -256,7 +256,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): ...@@ -256,7 +256,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
""" """
out_idx: list[int] | int | None out_idx: list[int] | int | None
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"]
target: str | Target target: str | Target
target_host: str | Target target_host: str | Target
verbose: bool verbose: bool
...@@ -424,7 +424,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): ...@@ -424,7 +424,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
return kernel return kernel
ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"]
@overload @overload
...@@ -473,7 +473,7 @@ def jit( # This is the new public interface ...@@ -473,7 +473,7 @@ def jit( # This is the new public interface
Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto". Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
target_host : Union[str, Target], optional target_host : Union[str, Target], optional
Target host for cross-compilation. Defaults to None. Target host for cross-compilation. Defaults to None.
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
Backend for kernel execution and argument passing. Use "auto" to pick a sensible Backend for kernel execution and argument passing. Use "auto" to pick a sensible
default per target (cuda->tvm_ffi, metal->torch, others->cython). default per target (cuda->tvm_ffi, metal->torch, others->cython).
verbose : bool, optional verbose : bool, optional
......
...@@ -4,3 +4,4 @@ from .ctypes import CtypesKernelAdapter # noqa: F401 ...@@ -4,3 +4,4 @@ from .ctypes import CtypesKernelAdapter # noqa: F401
from .cython import CythonKernelAdapter # noqa: F401 from .cython import CythonKernelAdapter # noqa: F401
from .nvrtc import NVRTCKernelAdapter # noqa: F401 from .nvrtc import NVRTCKernelAdapter # noqa: F401
from .torch import MetalKernelAdapter # noqa: F401 from .torch import MetalKernelAdapter # noqa: F401
from .cutedsl import CuTeDSLKernelAdapter # noqa: F401
"""CuTeDSL Backend for TileLang.
This module provides runtime compilation support using NVIDIA's CuTeDSL API.
"""
__all__ = [
"CuTeDSLKernelAdapter",
"TLCuTeDSLSourceWrapper",
"CuTeDSLLibraryGenerator",
"check_cutedsl_available",
]
from .checks import check_cutedsl_available # noqa: F401
from .adapter import CuTeDSLKernelAdapter # noqa: F401
from .wrapper import TLCuTeDSLSourceWrapper # noqa: F401
from .libgen import CuTeDSLLibraryGenerator # noqa: F401
This diff is collapsed.
from __future__ import annotations
import re
from importlib import metadata as _importlib_metadata
from importlib.util import find_spec as _find_spec
import os
_CUTEDSL_PUBLIC_DIST = "nvidia-cutlass-dsl"
_CUTEDSL_MIN_VERSION = (4, 3, 1)
_VERSION_TRIPLE_RE = re.compile(r"(\d+)\.(\d+)\.(\d+)")
def _parse_version_triple(version_str: str) -> tuple[int, int, int] | None:
    """Extract a best-effort (major, minor, patch) triple from *version_str*.

    Returns None when no `X.Y.Z` run of digits is present. Deliberately
    avoids heavyweight/optional version parsers: for the minimum-version
    gate (>= 4.3.1) a plain numeric triple comparison is sufficient.
    """
    match = _VERSION_TRIPLE_RE.search(version_str)
    if match is None:
        return None
    major, minor, patch = (int(part) for part in match.groups())
    return major, minor, patch
def _min_version_str() -> str:
    """Render the minimum supported CuTeDSL version as 'major.minor.patch'."""
    return ".".join(str(part) for part in _CUTEDSL_MIN_VERSION)
def _requirement_spec() -> str:
    """Build the pip requirement string, e.g. 'nvidia-cutlass-dsl>=4.3.1'."""
    return "{}>={}".format(_CUTEDSL_PUBLIC_DIST, _min_version_str())
def _installed_public_dist_version() -> str | None:
    """Best-effort version of the public `nvidia-cutlass-dsl` dist, or None."""
    try:
        return _importlib_metadata.version(_CUTEDSL_PUBLIC_DIST)
    except _importlib_metadata.PackageNotFoundError:
        return None
    except Exception:
        # Metadata is best-effort; don't block internal/nonstandard installs here.
        return None


def _raise_cute_missing() -> None:
    """Raise the canonical 'CuTe support missing' ImportError."""
    req = _requirement_spec()
    raise ImportError(f"CuTeDSL backend requires the CUTLASS Python DSL with CuTe support (install via `pip install -U '{req}'`).")


def check_cutedsl_available() -> None:
    """Fail fast if the CuTeDSL backend cannot be used in this Python environment.

    Policy:
    - If the public distribution `nvidia-cutlass-dsl` is installed, require version >= a minimum supported version.
    - Regardless of distribution metadata, require that `cutlass.cute` is importable.
    This intentionally does not mention or special-case any internal distributions.

    Raises:
        ImportError: when the installed dist is too old, or when no
            importable `cutlass` package with a `cute` submodule is found.
    """
    # 1) Version gate (only when the public dist metadata is present)
    dist_version = _installed_public_dist_version()
    if dist_version is not None:
        parsed = _parse_version_triple(dist_version)
        if parsed is None or parsed < _CUTEDSL_MIN_VERSION:
            req = _requirement_spec()
            raise ImportError(
                f"CuTeDSL backend requires `{req}`, but found version `{dist_version}`. Please run: `pip install -U '{req}'`."
            )
    # 2) Capability probe: keep it cheap.
    # Importing cutlass/cute can be expensive and defeats our lazy-import design,
    # especially on cache hits. We only require that the module is importable.
    cutlass_spec = _find_spec("cutlass")
    if cutlass_spec is None:
        _raise_cute_missing()
    # Avoid find_spec("cutlass.cute") which can be surprisingly expensive.
    # Instead, check for a 'cute' submodule/package under cutlass's search locations.
    locs = getattr(cutlass_spec, "submodule_search_locations", None)
    if not locs:
        _raise_cute_missing()
    for base in locs:
        if os.path.isdir(os.path.join(base, "cute")) or os.path.isfile(os.path.join(base, "cute.py")):
            return
    _raise_cute_missing()
"""CuTeDSL Library Generator for TileLang.
This module provides library generation functionality for the CuTeDSL backend.
"""
from __future__ import annotations
import importlib.util
import os
import tempfile
import subprocess
from tvm.target import Target
from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.jit.adapter.utils import is_cutedsl_target
class CuTeDSLLibraryGenerator(LibraryGenerator):
    """Generates and loads the runtime artifacts for the CuTeDSL backend.

    Unlike the C++ backends, the primary artifact is a generated Python
    module (kernel.py) holding the host wrapper; an optional C++ launcher
    shared library is compiled with nvcc when launcher code is supplied.
    """

    # Generated Python host wrapper source (embeds the @cute.kernel code).
    host_func: str | None = None
    # Optional C++ TMA-descriptor init source; stored via its updater but
    # not consumed anywhere in this class — presumably used by callers.
    tma_cpp_init_code: str | None = None
    # File name for the compiled TMA library, if any.
    tma_lib_name: str | None = None
    # Optional C++ launcher source compiled into a shared library via nvcc.
    launcher_cpp_code: str | None = None
    # File name to use for the compiled launcher shared library.
    launcher_lib_name: str | None = None
    # The loaded generated Python module (set by load_lib()).
    pymodule = None

    def __init__(self, target: Target, verbose: bool = False):
        super().__init__(target, verbose)

    @staticmethod
    def import_from_file(module_name, file_path):
        """Import a Python source file at `file_path` as module `module_name`."""
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def update_host_func(self, host_func: str):
        """Set the generated Python host wrapper source."""
        self.host_func = host_func

    def update_tma_cpp_init_code(self, tma_cpp_init_code: str):
        """Set the optional C++ TMA initialization source."""
        self.tma_cpp_init_code = tma_cpp_init_code

    def update_tma_lib_name(self, tma_lib_name: str):
        """Set the file name for the TMA library artifact."""
        self.tma_lib_name = tma_lib_name

    def update_launcher_cpp_code(self, launcher_cpp_code: str):
        """Set the optional C++ launcher source (triggers nvcc compile in compile_lib)."""
        self.launcher_cpp_code = launcher_cpp_code

    def update_launcher_lib_name(self, launcher_lib_name: str):
        """Set the file name for the compiled launcher shared library."""
        self.launcher_lib_name = launcher_lib_name

    def load_lib(self, lib_path: str | None = None):
        """Import the generated kernel module (from `lib_path` or self.libpath)."""
        if lib_path is None:
            if self.libpath is None:
                raise RuntimeError("CuTeDSLLibraryGenerator.libpath is not set; call compile_lib() first or pass lib_path explicitly.")
            lib_path = self.libpath
        self.pymodule = self.import_from_file("kernel", lib_path)

    def compile_lib(self, timeout: float | None = None):
        """Materialize the generated kernel module (and optional launcher .so).

        Writes `host_func` to <tmpdir>/kernel.py and, when launcher C++ code
        was provided, compiles it with nvcc into a shared library placed next
        to the generated module. `timeout` (seconds) bounds the nvcc call.

        Raises:
            RuntimeError: if host_func is unset or the nvcc compile fails.
            ValueError: if the target is not a CuTeDSL target.
        """
        if self.host_func is None:
            raise RuntimeError("CuTeDSLLibraryGenerator.host_func is not set; call update_host_func() before compile_lib().")
        target = self.target
        if is_cutedsl_target(target):
            # Use a dedicated temp directory per kernel so CuTeDSL artifacts (e.g. kept .cubin)
            # never pollute user CWD, and are easy to locate alongside the generated module.
            work_dir = tempfile.mkdtemp(prefix="tilelang_cutedsl_")
            src_path = os.path.join(work_dir, "kernel.py")
            with open(src_path, "w") as f:
                # Note: lib_code (containing @cute.kernel definitions) is embedded
                # inside host_func's _generate_cubin_if_needed function, so we only
                # write host_func here. This ensures cute imports are lazy-loaded.
                f.write(self.host_func)
            # Compile C++ launcher library if needed
            if self.launcher_cpp_code is not None:
                with tempfile.NamedTemporaryFile(
                        mode="w",
                        suffix=".cpp",
                        delete=False,
                ) as launcher_src:
                    launcher_src.write(self.launcher_cpp_code)
                    launcher_src_path = launcher_src.name
                # Generate launcher lib under the same directory as the source file
                launcher_lib_path = os.path.join(os.path.dirname(src_path), self.launcher_lib_name)
                # Get TVM FFI compiler flags using tvm_ffi.libinfo API
                try:
                    import tvm_ffi.libinfo
                    include_paths = tvm_ffi.libinfo.include_paths()
                    tvm_cxxflags = [f"-I{path}" for path in include_paths]
                    lib_path = tvm_ffi.libinfo.find_libtvm_ffi()
                    lib_dir = os.path.dirname(lib_path)
                    tvm_ldflags = [f"-L{lib_dir}", "-ltvm_ffi"]
                except (ImportError, RuntimeError):
                    # tvm_ffi unavailable or libinfo functions failed
                    tvm_cxxflags = []
                    tvm_ldflags = []
                # Compile with nvcc (need CUDA driver API)
                compile_cmd = [
                    "nvcc",
                    "-shared",
                    "-Xcompiler=-fPIC",
                    "-lcuda",
                    *tvm_cxxflags,
                    *tvm_ldflags,
                    "-o",
                    launcher_lib_path,
                    launcher_src_path,
                ]
                result = subprocess.run(compile_cmd, check=False, capture_output=True, text=True, timeout=timeout)
                if result.returncode != 0:
                    raise RuntimeError(f"Failed to compile C++ launcher: {result.stderr}")
                self.launcher_libpath = launcher_lib_path
                self.launcher_libname = self.launcher_lib_name
            self.srcpath = src_path
            # For CuTeDSL the "library" is the generated Python module itself.
            self.libpath = src_path
        else:
            raise ValueError(f"Unsupported target: {target}")
This diff is collapsed.
...@@ -76,7 +76,9 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -76,7 +76,9 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
self.wrapper.assign_pass_configs(pass_configs) self.wrapper.assign_pass_configs(pass_configs)
self.wrapper.assign_host_module(host_mod) self.wrapper.assign_host_module(host_mod)
self.wrapper.assign_device_module(device_mod) self.wrapper.assign_device_module(device_mod)
self.host_func, self.function_names = self.wrapper.wrap(device_kernel_source) wrapper_result = self.wrapper.wrap(device_kernel_source)
self.host_func = wrapper_result["host_func"]
self.function_names = wrapper_result["function_names"]
self.lib_generator = NVRTCLibraryGenerator(self.target, self.verbose) self.lib_generator = NVRTCLibraryGenerator(self.target, self.verbose)
self.lib_generator.update_lib_code(self.device_kernel_source) self.lib_generator.update_lib_code(self.device_kernel_source)
......
...@@ -273,7 +273,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -273,7 +273,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
Casts are noise in generated Python code - Python is dynamically typed. Casts are noise in generated Python code - Python is dynamically typed.
""" """
return pythonic_expr(expr, self._TYPE_MAP, ignore_cast=True) return pythonic_expr(expr, self._TYPE_MAP, ignore_cast=True, floor_div_op="//")
def create_dispatch_func(self, code, function_informations): def create_dispatch_func(self, code, function_informations):
"""Generate Python dispatch function that launches multiple CUDA kernels. """Generate Python dispatch function that launches multiple CUDA kernels.
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -15,6 +15,7 @@ SUPPORTED_TARGETS: dict[str, str] = { ...@@ -15,6 +15,7 @@ SUPPORTED_TARGETS: dict[str, str] = {
"llvm": "LLVM CPU target (accepts standard TVM LLVM options).", "llvm": "LLVM CPU target (accepts standard TVM LLVM options).",
"webgpu": "WebGPU target for browser/WebGPU runtimes.", "webgpu": "WebGPU target for browser/WebGPU runtimes.",
"c": "C source backend.", "c": "C source backend.",
"cutedsl": "CuTe DSL GPU target.",
} }
...@@ -95,6 +96,14 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj ...@@ -95,6 +96,14 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj
return_var = "metal" return_var = "metal"
else: else:
raise ValueError("No CUDA or HIP or MPS available on this system.") raise ValueError("No CUDA or HIP or MPS available on this system.")
elif isinstance(target, str) and target.startswith("cutedsl"):
cuda_target_str = target.replace("cutedsl", "cuda", 1)
temp_target = Target(cuda_target_str)
target_dict = dict(temp_target.export())
target_dict["keys"] = list(target_dict["keys"]) + ["cutedsl"]
return_var = Target(target_dict)
else: else:
# Validate the target if it's not "auto" # Validate the target if it's not "auto"
if isinstance(target, Target): if isinstance(target, Target):
...@@ -115,6 +124,8 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj ...@@ -115,6 +124,8 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj
else: else:
raise AssertionError(f"Target {target} is not supported") raise AssertionError(f"Target {target} is not supported")
if isinstance(return_var, Target):
return return_var
if return_object: if return_object:
if isinstance(return_var, Target): if isinstance(return_var, Target):
return return_var return return_var
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment