Commit 17f7394f authored by Gabriel Wu, committed by LeiWang1999

[Enhancement] Add nvrtc execution backend (#461)



* [wip] feat: add nvrtc backend

* [wip] fix: handle out_idx

* [wip] refactor: move lib logic to libgen

* feat: cache for nvrtc backend

* fmt: run format

* fix: handle cuda bindings import error

* fix: handle cuda bindings import error

* fix: handle cuda bindings import error

* fix: handle cuda bindings import error

* fix: get kernel source

* refactor: speedup pyimport

* Improve error handling for missing cuda-python dependency in nvrtc backend. Raise ImportError with detailed installation instructions instead of logging a warning.

* Enhance nvrtc backend error handling by introducing a flag to check for cuda-python availability. Raise ImportError with detailed installation instructions during initialization if the nvrtc backend is unavailable, improving user experience and clarity.

* Update README.md to include recent NVRTC Backend addition, highlighting reduced compilation time for CUDA templates.

* fix tl_templates

* ensure CUDA context

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 88c622c9
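For context, selecting the new backend is meant to work like the existing ones. A minimal, hedged sketch (assuming a TileLang `PrimFunc` defined elsewhere and the `compile` signature updated in this PR, re-exported at the package root):

```python
import tilelang

def build_nvrtc_kernel(prim_func):
    # Sketch only: `prim_func` is any TileLang PrimFunc (e.g. a matmul) defined elsewhere;
    # out_idx marks which arguments are allocated and returned as outputs.
    # "nvrtc" compiles the generated CUDA source to a cubin with NVRTC and launches it
    # through cuda-python instead of building a ctypes/Cython shared library.
    return tilelang.compile(prim_func, out_idx=[2], execution_backend="nvrtc")
```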
......@@ -13,6 +13,7 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
<img src=./images/MatmulExample.png />
## Latest News
- 05/06/2025 ✨: Added [NVRTC Backend](https://github.com/tile-ai/tilelang/pull/461) to significantly reduce compilation time for cute templates!
- 14/04/2025 🚀: Added high-performance FlashMLA implementation for AMD MI300X, achieving performance parity with hand-optimized assembly kernels of Aiter! See [example_mla_amd](./examples/deepseek_mla/amd/README.md) for details.
- 03/03/2025 🚀: Added high-performance MLA Decoding support using only 80 lines of Python code, achieving performance on par with FlashMLA on H100 (see [example_mla_decode.py](./examples/deepseek_mla/example_mla_decode.py))! We also provide [documentation](./examples/deepseek_mla/README.md) explaining how TileLang achieves this.
- 02/15/2025 ✨: Added WebGPU Codegen support, see [Pull Request #86](https://github.com/tile-ai/tilelang/pull/86)!
......
#pragma once
#ifndef __CUDACC_RTC__
#include <cuda_runtime.h>
#endif
#include <cutlass/fast_math.h>
#include <cutlass/numeric_types.h>
#include <math_constants.h>
......
......@@ -25,7 +25,7 @@ TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
static_assert(N == 16 || N == 8 || N == 4);
unsigned int addr = smem_ptr_to_uint(smem_addr);
if constexpr (N == 16) {
__asm__ __volatile__(
asm volatile(
#if TL_ENABLE_L2_PREFETCH
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
#else
......@@ -34,7 +34,7 @@ TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
::"r"(addr),
"l"((void *)(global_ptr)), "n"(N));
} else {
__asm__ __volatile__(
asm volatile(
#if TL_ENABLE_L2_PREFETCH
"cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
#else
......@@ -52,7 +52,7 @@ TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
int bytes = cond ? N : 0;
unsigned int addr = smem_ptr_to_uint(smem_addr);
if constexpr (N == 16) {
__asm__ __volatile__(
asm volatile(
#if TL_ENABLE_L2_PREFETCH
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
......@@ -61,7 +61,7 @@ TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
::"r"(addr),
"l"((void *)(global_ptr)), "n"(N), "r"(bytes));
} else {
__asm__ __volatile__(
asm volatile(
#if TL_ENABLE_L2_PREFETCH
"cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
......
#pragma once
#ifndef __CUDACC_RTC__
#include <cuda.h>
#endif
#include "common.h"
......
......@@ -2,7 +2,10 @@
#include "./cuda_fp8.h"
#include "common.h"
#include <stdio.h>
#ifndef __CUDACC_RTC__
#include <cstdio>
#endif
// Template declaration for device-side debug printing (variable only)
template <typename T> __device__ void debug_print_var(const char *msg, T var);
......
#pragma once
#include "common.h"
#include "cuda_fp8.h"
#include <cute/algorithm/clear.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits.hpp>
#include <cute/underscore.hpp>
#include "common.h"
#include "cuda_fp8.h"
namespace cute {
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
......
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
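// Minimal standard-library substitutes for NVRTC compilation (__CUDACC_RTC__), where host
// standard headers are unavailable: fixed-width integer aliases, a CUtensorMap fallback,
// and the std:: traits below (integral_constant, is_same, index_sequence, min/max,
// conditional, enable_if), presumably required by the CuTe/CUTLASS templates these kernels include.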
#ifdef __CUDACC_RTC__
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using int32_t = signed int;
using uint32_t = unsigned int;
using int64_t = signed long long;
using uint64_t = unsigned long long;
using cuuint64_t = unsigned long long;
#ifndef CU_TENSOR_MAP_NUM_QWORDS
#define CU_TENSOR_MAP_NUM_QWORDS 16
struct CUtensorMap_st {
#if defined(__cplusplus) && (__cplusplus >= 201103L)
alignas(64)
#elif __STDC_VERSION__ >= 201112L
_Alignas(64)
#endif
cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS];
};
using CUtensorMap = CUtensorMap_st;
#endif
namespace std {
template <class T, T v> struct integral_constant {
static constexpr T value = v;
using value_type = T;
using type = integral_constant;
__device__ constexpr operator value_type() const noexcept { return value; }
__device__ constexpr value_type operator()() const noexcept { return value; }
};
using false_type = integral_constant<bool, false>;
using true_type = integral_constant<bool, true>;
template <class T, class U> struct is_same : false_type {};
template <class T> struct is_same<T, T> : true_type {};
template <class T, class U>
inline constexpr bool is_same_v = is_same<T, U>::value;
namespace index_sequence_impl {
// Based on https://stackoverflow.com/a/32223343/11717224
template <size_t... Ints> struct index_sequence {
using type = index_sequence;
using value_type = size_t;
static constexpr size_t size() noexcept { return sizeof...(Ints); }
};
template <class Sequence1, class Sequence2> struct _merge_and_renumber;
template <size_t... I1, size_t... I2>
struct _merge_and_renumber<index_sequence<I1...>, index_sequence<I2...>>
: index_sequence<I1..., (sizeof...(I1) + I2)...> {};
template <size_t N>
struct make_index_sequence
: _merge_and_renumber<typename make_index_sequence<N / 2>::type,
typename make_index_sequence<N - N / 2>::type> {};
template <> struct make_index_sequence<0> : index_sequence<> {};
template <> struct make_index_sequence<1> : index_sequence<0> {};
} // namespace index_sequence_impl
template <size_t... Ns>
using index_sequence = index_sequence_impl::index_sequence<Ns...>;
template <size_t N>
using make_index_sequence = index_sequence_impl::make_index_sequence<N>;
template <typename T> constexpr T min(T a, T b) { return a < b ? a : b; }
template <typename T> constexpr T max(T a, T b) { return a > b ? a : b; }
template <bool B, class T, class F> struct conditional {
using type = T;
};
template <class T, class F> struct conditional<false, T, F> {
using type = F;
};
template <bool B, class T, class F>
using conditional_t = typename conditional<B, T, F>::type;
template <bool B, class T = void> struct enable_if {};
template <class T> struct enable_if<true, T> {
using type = T;
};
} // namespace std
#endif
\ No newline at end of file
......@@ -18,12 +18,12 @@ def cached(
*args,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Optional[Literal["dlpack", "ctypes", "cython"]] = "cython",
execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython",
verbose: Optional[bool] = False,
pass_configs: Optional[dict] = None,
) -> JITKernel:
"""
Caches and reuses compiled kerne(ls (using KernelCache class).
Caches and reuses compiled kernels (using KernelCache class).
"""
return _kernel_cache_instance.cached(
func,
......
"""The cache utils with class and database persistence - KernelCache Class"""
import os
import json
import logging
import os
import shutil
from pathlib import Path
import threading
from hashlib import sha256
from typing import Callable, List, Literal, Union, Optional
from pathlib import Path
from typing import Callable, List, Literal, Optional, Union
import cloudpickle
from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
from tilelang.engine.param import KernelParam
import threading
import cloudpickle
import logging
from tilelang.engine.param import KernelParam
from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled
from tilelang.jit import JITKernel
from tilelang.version import __version__
KERNEL_PATH = "kernel.cu"
WRAPPED_KERNEL_PATH = "wrapped_kernel.cu"
KERNEL_LIB_PATH = "kernel_lib.so"
KERNEL_CUBIN_PATH = "kernel.cubin"
KERNEL_PY_PATH = "kernel.py"
PARAMS_PATH = "params.pkl"
......@@ -36,6 +39,7 @@ class KernelCache:
_instance = None # For implementing singleton pattern
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython"
cache_dir: Path = Path(TILELANG_CACHE_DIR)
......@@ -66,7 +70,7 @@ class KernelCache:
self,
func: Callable,
out_idx: List[int],
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
args=None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
......@@ -86,6 +90,7 @@ class KernelCache:
Returns:
str: SHA256 hash key for the kernel configuration.
"""
self.execution_backend = execution_backend
func_binary = cloudpickle.dumps(func.script())
key_data = {
"version": __version__,
......@@ -99,8 +104,10 @@ class KernelCache:
"execution_backend": execution_backend,
"pass_configs": pass_configs,
}
key_string = json.dumps(key_data, sort_keys=True) # Sort keys to ensure consistency
return sha256(key_string.encode()).hexdigest() # Use SHA256 to generate hash key
# Sort keys to ensure consistency
key_string = json.dumps(key_data, sort_keys=True)
# Use SHA256 to generate hash key
return sha256(key_string.encode()).hexdigest()
def cached(
self,
......@@ -109,7 +116,7 @@ class KernelCache:
*args,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False,
pass_configs: dict = None,
) -> JITKernel:
......@@ -251,9 +258,15 @@ class KernelCache:
# Save kernel library
try:
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
if self.execution_backend == "nvrtc":
kernel_lib_path = os.path.join(cache_path, KERNEL_CUBIN_PATH)
else:
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
src_lib_path = kernel.adapter.libpath
shutil.copy(src_lib_path, kernel_lib_path)
if self.execution_backend == "nvrtc":
shutil.copy(
src_lib_path.replace(".cubin", ".py"), os.path.join(cache_path, KERNEL_PY_PATH))
except Exception as e:
self.logger.error(f"Error saving kernel library to disk: {e}")
......@@ -271,7 +284,7 @@ class KernelCache:
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
out_idx: List[int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
pass_configs: dict = None,
func: Callable = None,
) -> JITKernel:
......@@ -304,7 +317,10 @@ class KernelCache:
except Exception as e:
self.logger.error(f"Error loading wrapped kernel source code from disk: {e}")
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
if self.execution_backend == "nvrtc":
kernel_lib_path = os.path.join(cache_path, KERNEL_CUBIN_PATH)
else:
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
# Load kernel parameters
try:
......@@ -332,7 +348,7 @@ class KernelCache:
def _clear_disk_cache(self):
"""
Removes all cached kernels from disk.
Note:
This operation will delete the entire cache directory and recreate it empty.
Use with caution as this operation cannot be undone.
......@@ -340,6 +356,7 @@ class KernelCache:
try:
if os.path.exists(self.cache_dir):
shutil.rmtree(self.cache_dir) # Delete entire cache directory
os.makedirs(self.cache_dir, exist_ok=True) # Re-create cache directory
# Re-create cache directory
os.makedirs(self.cache_dir, exist_ok=True)
except Exception as e:
self.logger.error(f"Error clearing disk cache: {e}")
import cuda.bindings.nvrtc as nvrtc
from typing import Literal, Union, List, Optional, Tuple
from tvm.target import Target
from .nvcc import get_target_compute_version
def get_nvrtc_version() -> Tuple[int, int]:
result, major, minor = nvrtc.nvrtcVersion()
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get NVRTC version: {result}"
return (major, minor)
def compile_cuda(code: str,
target_format: Literal["ptx", "cubin"] = "ptx",
arch: Optional[int] = None,
options: Optional[Union[str, List[str]]] = None,
verbose: bool = False) -> bytearray:
"""Compile cuda code with NVRTC.
Parameters
----------
code : str
The cuda code.
target_format : Literal["ptx", "cubin"]
The target format of nvrtc compiler.
arch : Optional[int]
The cuda architecture code.
options : Optional[Union[str, List[str]]]
The additional options.
verbose : bool
Whether to print the verbose output.
Return
------
result_bytes : bytearray
The bytearray of the cubin or ptx code.
"""
if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "80", "90", "90a", etc.
compute_version = "".join(
get_target_compute_version(Target.current(allow_none=True)).split("."))
arch = int(compute_version)
prefix = "compute" if target_format == "ptx" else "sm"
suffix = "a" if arch >= 90 else ""
arch_option = f"--gpu-architecture={prefix}_{arch}{suffix}"
file_name = "tvm_kernels"
if target_format not in ["cubin", "ptx"]:
raise ValueError("target_format must be cubin or ptx")
final_options = ["-default-device"]
if get_nvrtc_version() >= (12, 8):
final_options += ["-pch"]
if arch is not None:
final_options += [arch_option]
if options:
if isinstance(options, str):
final_options += [options]
elif isinstance(options, list):
final_options += options
else:
raise ValueError("options must be str or list of str")
code = "#include <tl_templates/cuda/nvrtc_std.h>\n" + code
code_bytes = bytes(code, "utf-8")
result, program = nvrtc.nvrtcCreateProgram(code_bytes, bytes(file_name, "utf-8"), 0, [], [])
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to create program: {result}"
options_bytes = [bytes(flag, "utf-8") for flag in final_options]
compile_result = nvrtc.nvrtcCompileProgram(program, len(options_bytes), options_bytes)[0]
if compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
msg = f"{code}\n" \
f"Compilation error:\n"
if verbose:
result, log_size = nvrtc.nvrtcGetProgramLogSize(program)
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get program log size: {result}"
log_bytes = bytes(log_size)
result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get program log: {result}"
msg += f"{log_bytes.decode('utf-8')}\n"
else:
msg += "Turn on verbose to see the full compilation log."
msg += f"Options: {' '.join(final_options)}\n"
raise RuntimeError(msg)
if target_format == "cubin":
result, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get CUBIN size: {result}"
result_bytes = bytes(cubin_size)
result = nvrtc.nvrtcGetCUBIN(program, result_bytes)[0]
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get CUBIN: {result}"
else:
result, ptx_size = nvrtc.nvrtcGetPTXSize(program)
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get PTX size: {result}"
result_bytes = bytes(ptx_size)
result = nvrtc.nvrtcGetPTX(program, result_bytes)[0]
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get PTX: {result}"
# Destroy handler
assert nvrtc.nvrtcDestroyProgram(
program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to destroy program: {result}"
return result_bytes
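# Minimal usage sketch (illustrative only, not part of the module above): compile a trivial
# kernel to a cubin. Assumes cuda-python is installed and that TILELANG_TEMPLATE_PATH points
# at the directory containing tl_templates/, since compile_cuda() prepends an include of
# tl_templates/cuda/nvrtc_std.h.
if __name__ == "__main__":
    from tilelang.env import TILELANG_TEMPLATE_PATH

    demo_src = r"""
    extern "C" __global__ void add_one(float *x, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) x[i] += 1.0f;
    }
    """
    cubin = compile_cuda(
        demo_src,
        target_format="cubin",
        arch=80,  # sm_80; with arch=None the architecture is derived from the current TVM target
        options=[f"-I{TILELANG_TEMPLATE_PATH}"],
        verbose=True,
    )
    print(f"Compiled cubin: {len(cubin)} bytes")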
......@@ -32,7 +32,7 @@ logger = getLogger(__name__)
def compile(
func: PrimFunc = None,
out_idx: Union[List[int], int, None] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
verbose: bool = False,
......@@ -46,8 +46,8 @@ def compile(
The TileLang TIR function to compile and wrap.
out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
execution_backend : Literal["dlpack", "ctypes"], optional
Execution backend to use for kernel execution (default: "dlpack").
execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
Execution backend to use for kernel execution (default: "cython").
target : Union[str, Target], optional
Compilation target, either as a string or a TVM Target object (default: "auto").
target_host : Union[str, Target], optional
......
......@@ -2,3 +2,4 @@ from .base import BaseKernelAdapter # noqa: F401
from .dlpack import TorchDLPackKernelAdapter # noqa: F401
from .ctypes import CtypesKernelAdapter # noqa: F401
from .cython import CythonKernelAdapter # noqa: F401
from .nvrtc import NVRTCKernelAdapter # noqa: F401
\ No newline at end of file
......@@ -177,7 +177,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
ctypes_args.append(ctypes.c_void_p(stream))
self.lib.call(*ctypes_args)
def _warp_forward_from_prebuild_lib(self,
def _wrap_forward_from_prebuild_lib(self,
*ins: List[torch.Tensor],
stream: Optional[int] = None):
"""High-level wrapper for kernel execution.
......@@ -241,7 +241,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
def _convert_torch_func(self) -> Callable:
"""Returns a PyTorch-compatible function wrapper for the kernel."""
return self._warp_forward_from_prebuild_lib
return self._wrap_forward_from_prebuild_lib
@property
def prim_func(self) -> tir.PrimFunc:
......
from typing import Optional
from .utils import is_cuda_target, is_hip_target, is_cpu_target
from tilelang import tvm as tvm
from tilelang.contrib.nvcc import get_target_compute_version, get_nvcc_compiler
from tvm.target import Target
import ctypes
import importlib
import logging
import os
import tempfile
import os.path as osp
import subprocess
import logging
from tilelang.env import TILELANG_TEMPLATE_PATH
import tempfile
from typing import Optional
from tvm.target import Target
from tilelang import tvm as tvm
from tilelang.contrib.nvcc import get_nvcc_compiler, get_target_compute_version
from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch
from tilelang.env import TILELANG_TEMPLATE_PATH
from .utils import is_cpu_target, is_cuda_target, is_hip_target
logger = logging.getLogger(__name__)
is_nvrtc_available = False
NVRTC_UNAVAILABLE_WARNING = "cuda-python is not available, nvrtc backend cannot be used. " \
"Please install cuda-python via `pip install cuda-python` " \
"if you want to use the nvrtc backend."
try:
import cuda.bindings.driver as cuda
from tilelang.contrib.nvrtc import compile_cuda
is_nvrtc_available = True
except ImportError:
pass
class LibraryGenerator(object):
srcpath: Optional[str] = None
......@@ -127,3 +143,88 @@ class LibraryGenerator(object):
def set_src_path(self, srcpath):
self.srcpath = srcpath
class PyLibraryGenerator(LibraryGenerator):
host_func: Optional[str] = None
culib = None
pymodule = None
def __init__(self, target: Target):
if not is_nvrtc_available:
raise ImportError(NVRTC_UNAVAILABLE_WARNING)
super().__init__(target)
@staticmethod
def import_from_file(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def update_host_func(self, host_func: str):
self.host_func = host_func
def load_lib(self, lib_path: Optional[str] = None):
if lib_path is None:
lib_path = self.libpath
pypath = lib_path.replace(".cubin", ".py")
self.pymodule = self.import_from_file("kernel", pypath)
# Ensure the context is valid
ctx = cuda.cuCtxGetCurrent()[1]
if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS:
import torch
torch.cuda.synchronize()
result, self.culib = cuda.cuLibraryLoadFromFile(
bytes(lib_path, "utf-8"), [], [], 0, [], [], 0)
assert result == cuda.CUresult.CUDA_SUCCESS, f"Failed to load library: {lib_path}"
def compile_lib(self, timeout: float = None):
target = self.target
if is_cuda_target(target):
from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH)
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
libpath = src.name.replace(".cu", ".cubin")
project_root = osp.join(osp.dirname(__file__), "..", "..")
if CUTLASS_INCLUDE_DIR is None:
cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include"))
else:
cutlass_path = CUTLASS_INCLUDE_DIR
if TILELANG_TEMPLATE_PATH is None:
tl_template_path = osp.abspath(osp.join(project_root, "src"))
else:
tl_template_path = TILELANG_TEMPLATE_PATH
cuda_home = "/usr/local/cuda" if CUDA_HOME is None else CUDA_HOME
cubin_bytes = compile_cuda(
self.lib_code,
target_format="cubin",
options=[f"-I{tl_template_path}", f"-I{cutlass_path}", f"-I{cuda_home}/include"],
verbose=True)
with open(libpath, "wb") as f:
f.write(cubin_bytes)
src.write(self.lib_code)
src.flush()
self.srcpath = src.name
self.libpath = libpath
pypath = src.name.replace(".cu", ".py")
with open(pypath, "w") as f:
f.write(self.host_func)
else:
raise ValueError(f"Unsupported target: {target}")
def __del__(self):
if self.culib:
result = cuda.cuLibraryUnload(self.culib)[0]
if result != cuda.CUresult.CUDA_SUCCESS:
logger.warning(f"Failed to unload library: {self.libpath}")
self.culib = None
from .adapter import NVRTCKernelAdapter # noqa: F401
import logging
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from tvm import tir
from tvm.target import Target
from tilelang import tvm as tvm
from tilelang.engine.param import KernelParam
from tilelang.jit.adapter.wrapper import TLPyWrapper
from tilelang.jit.adapter.libgen import PyLibraryGenerator
from tilelang.utils.language import retrieve_func_from_module
from tilelang.utils.target import determine_target
from ..base import BaseKernelAdapter
logger = logging.getLogger(__name__)
is_nvrtc_available = False
NVRTC_UNAVAILABLE_WARNING = "cuda-python is not available, nvrtc backend cannot be used. " \
"Please install cuda-python via `pip install cuda-python` " \
"if you want to use the nvrtc backend."
try:
import cuda.bindings.driver as cuda
is_nvrtc_available = True
except ImportError:
pass
class NVRTCKernelAdapter(BaseKernelAdapter):
pymodule = None
kernels = {}
def __init__(self,
params: List[KernelParam],
result_idx: List[int],
target: Union[str, Target],
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
host_mod: Optional[tvm.IRModule] = None,
device_mod: Optional[tvm.IRModule] = None,
kernel_global_source: Optional[str] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
if not is_nvrtc_available:
raise ImportError(NVRTC_UNAVAILABLE_WARNING)
self.params = params
self.result_idx = self._legalize_result_idx(result_idx)
self.kernel_global_source = kernel_global_source
if isinstance(func_or_mod, tir.PrimFunc):
self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
else:
self.ir_module = func_or_mod
# Cache parameter information during initialization
self.param_dtypes = [param.dtype for param in params]
self.param_shapes = []
for param in params:
native_shape = []
for dim in param.shape:
if isinstance(dim, tir.IntImm):
native_shape.append(int(dim))
elif isinstance(dim, tir.Var):
# Keep tir.Var for dynamic dimensions
native_shape.append(dim)
else:
native_shape.append(dim)
self.param_shapes.append(native_shape)
self.dynamic_symbolic_map = self._process_dynamic_symbolic()
self.target = Target.canon_target(determine_target(target))
self.verbose = verbose
self.wrapper = TLPyWrapper(self.target)
self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs)
self.wrapper.assign_host_module(host_mod)
self.wrapper.assign_device_module(device_mod)
self.host_func, self.function_names = self.wrapper.wrap(kernel_global_source)
self.lib_generator = PyLibraryGenerator(self.target)
self.lib_generator.update_lib_code(self.kernel_global_source)
self.lib_generator.update_host_func(self.host_func)
self.lib_generator.compile_lib()
self.lib_generator.load_lib()
self.libpath = self.lib_generator.libpath
self.pymodule = self.lib_generator.pymodule
culib = self.lib_generator.culib
for name in self.function_names:
result, self.kernels[name] = cuda.cuLibraryGetKernel(culib, bytes(name, "utf-8"))
assert result == cuda.CUresult.CUDA_SUCCESS, f"Failed to get kernel: {name}"
self._post_init()
@classmethod
def from_database(cls,
params: List[KernelParam],
result_idx: List[int],
target: str,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
kernel_global_source: str,
kernel_lib_path: str,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
adapter.kernel_global_source = kernel_global_source
if isinstance(func_or_mod, tir.PrimFunc):
adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
else:
adapter.ir_module = func_or_mod
# Cache parameter information during initialization
adapter.param_dtypes = [param.dtype for param in params]
adapter.param_shapes = []
for param in params:
native_shape = []
for dim in param.shape:
if isinstance(dim, tir.IntImm):
native_shape.append(int(dim))
elif isinstance(dim, tir.Var):
# Keep tir.Var for dynamic dimensions
native_shape.append(dim)
else:
native_shape.append(dim)
adapter.param_shapes.append(native_shape)
adapter.dynamic_symbolic_map = adapter._process_dynamic_symbolic()
adapter.target = Target.canon_target(determine_target(target))
adapter.verbose = verbose
adapter.lib_generator = PyLibraryGenerator(adapter.target)
adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
adapter.pymodule = adapter.lib_generator.pymodule
adapter.function_names = adapter.pymodule._function_names
culib = adapter.lib_generator.culib
for name in adapter.function_names:
result, adapter.kernels[name] = cuda.cuLibraryGetKernel(culib, bytes(name, "utf-8"))
assert result == cuda.CUresult.CUDA_SUCCESS, f"Failed to get kernel: {name}"
adapter._post_init()
return adapter
def _process_dynamic_symbolic(self):
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
for runtime shape resolution.
"""
func = self.prim_func
params = func.params
buffer_map = func.buffer_map
dynamic_symbolic_map = {}
for i, param in enumerate(params):
buffer = buffer_map[param]
for j, shape in enumerate(buffer.shape):
if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map):
dynamic_symbolic_map[shape] = (i, j)
return dynamic_symbolic_map
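# Illustrative example (assumed shapes, not from the source): for a kernel taking
# A:(m, k), B:(k, n), C:(m, n), the map comes out as {m: (0, 0), k: (0, 1), n: (1, 1)},
# i.e. "m is dim 0 of argument 0". _wrap_forward_from_prebuild_lib uses it to size the
# output tensors and to append the trailing scalar shape arguments from the inputs at call time.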
def get_kernel_source(self):
return self.kernel_global_source
def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
"""Low-level function to call the compiled CUDA kernel.
"""
return self.pymodule.call(self.kernels, *args, stream=stream)
def _wrap_forward_from_prebuild_lib(self,
*ins: List[torch.Tensor],
stream: Optional[int] = None):
"""High-level wrapper for kernel execution.
Handles:
1. Input validation
2. Output tensor allocation
3. Dynamic shape resolution
4. CUDA stream management
Args:
ins: Input PyTorch tensors
stream: Optional CUDA stream for asynchronous execution
Returns:
Single tensor or list of tensors containing the kernel results
"""
if len(ins) + len(self.result_idx) != len(self.params):
raise ValueError(
f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs"
)
ins_idx = 0
args = []
# tensor pointers
for i in range(len(self.params)):
if i in self.result_idx:
dtype = self.param_dtypes[i]
shape = []
# Now working with native Python list, no FFI calls needed
for s in self.param_shapes[i]:
if isinstance(s, tir.Var):
ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[s]
shape.append(ins[ref_tensor_idx].shape[ref_shape_idx])
else: # Already converted to Python int during initialization
shape.append(s)
device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
tensor = torch.empty(*shape, dtype=dtype, device=device)
else:
tensor = ins[ins_idx]
ins_idx += 1
args.append(tensor)
# dynamic symbolics
for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
args.append(ins[buffer_idx].shape[shape_idx])
# if stream is not None, we need to pass the stream to the library
if stream is None:
if str(self.target).startswith("cuda") and torch.cuda.is_available():
stream = torch.cuda.current_stream().cuda_stream
else:
stream = 0
self._forward_from_prebuild_lib(*args, stream=stream)
if len(self.result_idx) == 1:
return args[self.result_idx[0]]
else:
return [args[i] for i in self.result_idx]
def _convert_torch_func(self) -> Callable:
return self._wrap_forward_from_prebuild_lib
@property
def prim_func(self) -> tir.PrimFunc:
"""Returns the primary TIR function from the IR module."""
return retrieve_func_from_module(self.ir_module)
......@@ -46,6 +46,16 @@ extern "C" int call({}) {{
}}
"""
PREDEF_HOST_FUNC_PY = """
import cuda.bindings.driver
import ctypes
_function_names = {}
def call({}):
{}
"""
L2_PERSISTENT_MAP_CREATE_HANDLE = """
\tcudaStreamAttrValue stream_attribute;
\tsize_t init_persisting_l2_cache_size;
......@@ -94,6 +104,65 @@ TMA_DESC_INIT_FUNC = """
\t}}
"""
TMA_DESC_INIT_FUNC_PY = """
\t{0}_type = cuda.bindings.driver.CUtensorMapDataType({1})
\t{0}_tensorRank = {2}
\t{0}_globalAddress = {3}.data_ptr()
\t{0}_globalDim = [{4}]
\t{0}_globalStride = [{5}][1:]
\t{0}_boxDim = [{6}]
\t{0}_elementStrides = [{7}]
\t{0}_interleave = cuda.bindings.driver.CUtensorMapInterleave({8})
\t{0}_swizzle = cuda.bindings.driver.CUtensorMapSwizzle({9})
\t{0}_l2Promotion = cuda.bindings.driver.CUtensorMapL2promotion({10})
\t{0}_oobFill = cuda.bindings.driver.CUtensorMapFloatOOBfill({11})
\tres, {0} = cuda.bindings.driver.cuTensorMapEncodeTiled(
\t\t{0}_type,
\t\t{0}_tensorRank,
\t\t{0}_globalAddress,
\t\t{0}_globalDim,
\t\t{0}_globalStride,
\t\t{0}_boxDim,
\t\t{0}_elementStrides,
\t\t{0}_interleave,
\t\t{0}_swizzle,
\t\t{0}_l2Promotion,
\t\t{0}_oobFill,
\t)
\tif res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
\t\traise RuntimeError(f"Failed to initialize the TMA descriptor {0}: {{res}}")
"""
KERNEL_LAUNCH_FUNC_PY = """
\tres = cuda.bindings.driver.cuKernelSetAttribute(
\t\tcuda.bindings.driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
\t\t{7},
\t\tkernels["{0}"],
\t\tcuda.bindings.driver.CUdevice({10})
\t)[0]
\tif res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
\t\traise RuntimeError(f"Failed to set max dynamic shared memory size to {7} for kernel {0}: {{res}}")
\tconfig = cuda.bindings.driver.CUlaunchConfig()
\tconfig.gridDimX = {1}
\tconfig.gridDimY = {2}
\tconfig.gridDimZ = {3}
\tconfig.blockDimX = {4}
\tconfig.blockDimY = {5}
\tconfig.blockDimZ = {6}
\tconfig.sharedMemBytes = {7}
\tconfig.hStream = stream
\targ_values = {8}
\targ_types = {9}
\tres = cuda.bindings.driver.cuLaunchKernelEx(config, kernels["{0}"], (arg_values, arg_types), 0)[0]
\tif res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
\t\traise RuntimeError(f"Failed to launch kernel {0}: {{res}}")
"""
class BaseWrapper(ABC):
......@@ -470,6 +539,219 @@ class TLCUDASourceWrapper(object):
raise ValueError("Cannot find primary function in the module.")
class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
"""
A wrapper class for the TileLang NVRTC backend.
"""
_TYPE_MAP = {
"float32": "ctypes.c_float",
"float16": "ctypes.c_uint16",
"bfloat16": "ctypes.c_uint16",
"e4m3_float8": "ctypes.c_uint8",
"e5m2_float8": "ctypes.c_uint8",
"float64": "ctypes.c_double",
"int64": "ctypes.c_int64",
"int32": "ctypes.c_int32",
"uint32": "ctypes.c_uint32",
"bool": "ctypes.c_bool",
"int8": "ctypes.c_int8",
"uint8": "ctypes.c_uint8",
"int16": "ctypes.c_int16",
"uint16": "ctypes.c_uint16",
"uchar": "ctypes.c_uint8",
}
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: Optional[IRModule] = None,
host_mod: Optional[IRModule] = None,
pass_configs: Optional[Dict[str, Any]] = None):
super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)
def create_dispatch_func(self, code, function_informations):
# Extract the set of dynamic symbolic names used in the primary function
dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
function_args = [{"name": "kernels", "type": "Dict[str, cuda.bindings.driver.CUkernel]"}]
# Collect function arguments based on primary function's parameters and buffer mappings
for param in self.prim_func.params:
if param in self.prim_func.buffer_map:
buffer = self.prim_func.buffer_map[param]
function_args.append({
"name": buffer.data.name,
"type": "ctypes.c_void_p",
})
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]})
else:
raise ValueError(
f"Parameter {param} is not in the buffer map of the primary function.")
# Add dynamic symbols as integer arguments
for dyn_sym in dynamic_symbolic_set:
if dyn_sym not in [arg["name"] for arg in function_args]:
function_args.append({"name": dyn_sym, "type": "ctypes.c_int"})
function_args.append(self.get_stream_type())
# Format the function arguments for declaration
def_args = ", ".join([f"{arg['name']}" for arg in function_args])
def func_call_args(s, function_args, desc_name_map: Optional[Dict[str, str]] = None):
# Extract the function call arguments matching the function definition
def maybe_desc(name: str, matches: List[str], i: int):
match = matches[i]
if not (match == name + "_desc" or match.startswith(name + "_desc_")):
return False
desc_decls = []
if desc_name_map is not None:
desc_name_map[match] = name
if i > 0:
desc_decls.append(matches[i - 1])
if i < len(matches) - 1:
desc_decls.append(matches[i + 1])
return any([decl == "CUtensorMap" for decl in desc_decls])
pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
matches = re.findall(pattern, s)
call_args = []
for i, match in enumerate(matches):
for arg in function_args:
if arg["name"] == match:
call_args.append(
(f"{match}.data_ptr()" if arg["type"] == "ctypes.c_void_p" else match,
arg["type"]))
elif maybe_desc(arg["name"], matches, i):
call_args.append((match, "None"))
return call_args
def legalize(p):
# Convert TIR expressions to legal Python expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p)
desc_name_map: Dict[str, str] = {}
device_index = 0
kernel_launch_code = """"""
for function_name, function_info in function_informations.items():
block_info = function_info["block_info"]
grid_info = function_info["grid_info"]
dynamic_smem_buf = function_info["dynamic_smem_buf"]
# Find the location of the global kernel function in the code
index = match_declare_kernel(code, function_name + "(")
# Analyze the function declaration to prepare for argument extraction
declaration = code[index:].split(";")[0]
# Identify the start of the function body to insert arguments
index = code.index("{", index)
call_args = func_call_args(declaration, function_args, desc_name_map)
for arg_name, arg_type in call_args:
if arg_type == "ctypes.c_void_p":
device_index = f"{arg_name.replace('.data_ptr()', '')}.device.index"
break
arg_names = ", ".join([arg[0] for arg in call_args])
arg_types = ", ".join([arg[1] for arg in call_args])
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
kernel_launch_code += self.generate_tma_descriptor_args(
desc_name_map) + KERNEL_LAUNCH_FUNC_PY.format(
function_name, legalize(grid_info[0]), legalize(grid_info[1]),
legalize(grid_info[2]), legalize(block_info[0]), legalize(block_info[1]),
legalize(block_info[2]), smem_str, arg_names, arg_types, device_index)
# Wrap the kernel dispatch logic in an external C function
host_func = PREDEF_HOST_FUNC_PY.format(
repr(list(function_informations.keys())), def_args, kernel_launch_code)
return host_func
def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
tma_descripter_init = ""
if self.tma_descriptor_args is None:
return tma_descripter_init
for handle_name, name in desc_name_map.items():
desc_name = name + "_desc"
assert desc_name in self.tma_descriptor_args, f"TMA descriptor {desc_name} not found in {self.tma_descriptor_args}"
args = self.tma_descriptor_args[desc_name]
# Skip __tvm_tensormap_create_tiled
if len(args) < 3:
raise ValueError(
f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
_, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
tensor_rank = int(tensor_rank)
# Validate tensor_rank
if not isinstance(tensor_rank, int) or tensor_rank <= 0:
raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer")
# Calculate required length for remaining_args
# 4 groups of tensor_rank size + 4 parameters
expected_args_len = 4 * tensor_rank + 4
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
global_dim = [str(i) for i in global_dim]
global_stride = [str(i) for i in global_stride]
box_dim = [str(i) for i in box_dim]
element_strides = [str(i) for i in element_strides]
# Extract remaining parameters
try:
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e
tma_descripter_init += TMA_DESC_INIT_FUNC_PY.format(
handle_name, dtype, tensor_rank, globalAddress,
", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", global_dim)),
", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", global_stride)),
", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", box_dim)),
", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})",
element_strides)), interleave, swizzle, l2Promotion, oobFill)
return tma_descripter_init
def update_lib_code(self, code: str):
# Update the library code with the given code string
self.lib_code = code
# Organize function information for code generation
function_informations = {}
for function_name in self.function_names:
# Do not update function with dispatch host function
if (function_name not in self.block_info) or (function_name not in self.grid_info):
continue
function_informations[function_name] = {
"function_name": function_name,
"block_info": self.block_info[function_name],
"grid_info": self.grid_info[function_name],
"dynamic_smem_buf": self.dynamic_smem_buf[function_name],
}
# Create the host function wrapper for the CUDA kernel
self.host_func = self.create_dispatch_func(code, function_informations)
return self.lib_code
def get_stream_type(self) -> Dict[str, str]:
return {"name": "stream=0", "type": "int"}
class TLHIPSourceWrapper(TLCUDASourceWrapper):
"""
A wrapper class for the TileLang HIP backend.
......@@ -745,3 +1027,24 @@ class TLWrapper(BaseWrapper):
host_mod=self.host_mod,
pass_configs=self.pass_configs)
return wrapper.lib_code
class TLPyWrapper(TLWrapper):
def __init__(self, target: Target):
super().__init__(target)
def wrap(self, c_source: str):
# assert self.scheduled_ir_module is not None, "Please assign optimized module first."
if is_cuda_target(self.target):
wrapper_class = TLNVRTCSourceWrapper
else:
raise ValueError(f"Unsupported platform: {self.arch.platform}")
wrapper = wrapper_class(
scheduled_ir_module=self.scheduled_ir_module,
source=c_source,
target=self.target,
device_mod=self.device_mod,
host_mod=self.host_mod,
pass_configs=self.pass_configs)
return wrapper.host_func, wrapper.function_names
from typing import List, Union, Any, Callable, Literal, Optional, Dict
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from tvm.target import Target
import tilelang
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tilelang.jit.adapter import (
TorchDLPackKernelAdapter,
BaseKernelAdapter,
CtypesKernelAdapter,
CythonKernelAdapter,
)
from tilelang.utils.target import determine_target, AVALIABLE_TARGETS
import tilelang
from tilelang import tvm as tvm
from tilelang.engine.param import CompiledArtifact, KernelParam
from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
NVRTCKernelAdapter, TorchDLPackKernelAdapter)
from tilelang.profiler import Profiler, TensorSupplyType
from tilelang.engine.param import KernelParam, CompiledArtifact
from tilelang.utils.target import AVALIABLE_TARGETS, determine_target
class JITKernel(object):
......@@ -42,7 +39,7 @@ class JITKernel(object):
self,
func: PrimFunc = None,
out_idx: Union[List[int], int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
verbose: bool = False,
......@@ -58,8 +55,8 @@ class JITKernel(object):
The TileLang TIR function to compile and wrap.
out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
execution_backend : Literal["dlpack", "ctypes"], optional
Execution backend to use for kernel execution (default: "dlpack").
execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
Execution backend to use for kernel execution (default: "cython").
target : Union[str, Target], optional
Compilation target, either as a string or a TVM Target object (default: "auto").
target_host : Union[str, Target], optional
......@@ -99,6 +96,7 @@ class JITKernel(object):
"dlpack",
"ctypes",
"cython",
"nvrtc",
], f"Invalid execution backend. {execution_backend}"
if execution_backend == "cython":
from tilelang.contrib.cc import get_cplus_compiler
......@@ -127,7 +125,7 @@ class JITKernel(object):
target: Union[str, Target],
target_host: Union[str, Target],
out_idx: Union[List[int], int],
execution_backend: Literal["dlpack", "ctypes", "cython"],
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"],
pass_configs: Optional[Dict[str, Any]] = None,
):
"""
......@@ -240,6 +238,18 @@ class JITKernel(object):
verbose=verbose,
pass_configs=pass_configs,
)
elif execution_backend == "nvrtc":
adapter = NVRTCKernelAdapter(
params=artifact.params,
result_idx=out_idx,
target=target,
func_or_mod=tilelang_func,
host_mod=artifact.host_mod,
device_mod=artifact.device_mod,
kernel_global_source=artifact.kernel_source,
verbose=verbose,
pass_configs=pass_configs,
)
else:
# Handle invalid backend.
raise ValueError(f"Invalid execution backend: {execution_backend}")
......@@ -282,6 +292,16 @@ class JITKernel(object):
kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs,
)
elif execution_backend == "nvrtc":
adapter = NVRTCKernelAdapter.from_database(
params=params,
result_idx=result_idx,
target=target,
func_or_mod=func_or_mod,
kernel_global_source=kernel_global_source,
kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs,
)
else:
# Handle invalid backend.
raise ValueError(f"Invalid execution backend: {execution_backend}")
......@@ -334,7 +354,7 @@ class JITKernel(object):
str
The source code of the compiled kernel function.
"""
if self.execution_backend in {"ctypes", "cython"}:
if self.execution_backend in {"ctypes", "cython", "nvrtc"}:
return self.adapter.get_kernel_source()
return self.artifact.kernel_source
......