"...composable_kernel_onnxruntime.git" did not exist on "76f3131939fb6bd0ed34cfac3be3b92c672b49e6"
Unverified Commit eb415744 authored by Gabriel Wu, committed by GitHub

[fix] NVRTC execution backend (#1256)

* [fix] NVRTC execution backend

* [fmt] run pre-commit

* [fix] coderabbit reviews

* [test] add cuda-python to test dep

* [fix] coderabbit reviews

* [fix] CUDA 13 compatibility

* [fix] sm90

* [fix] CUDA 13 compatibility

* [fix] pre-commit

* [fix] always use cuda::std::__atomic_ref_impl

* [fix] restore to external API

* Revert "[fix] restore to external API"

This reverts commit 49bd875638fb631d270015f408991d38fd1e9a5d.

* [fmt] use space instead tabs for py codegen

* [fix] im2col API

* [fix] revert atomic.h

* [fix] dynamic shape

* [refactor] extract common utils

* [feat] support L2 persistent map

* [fix] l2 persistent map

* [fix] pre-commit

* [fix] restore _TYPE_MAP

* [fix] pre-commit

* [fix] avoid duplicate TMA descs

* [docs] add docstring

* [fix] coderabbit

* [fix] coderabbit

* [fix] coderabbit

* [fix] coderabbit
parent 0af3fd7c
...@@ -6,3 +6,4 @@
# CUDA specific requirements
flash-attn==2.5.8
cuda-python==12.9.4
...@@ -4,8 +4,10 @@
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
#ifndef __CUDACC_RTC__
#include <type_traits>
#include <utility>
#endif
namespace tl {
...
...@@ -2,8 +2,10 @@
#include "../common.h"
#ifndef __CUDACC_RTC__
#include <type_traits>
#include <utility>
#endif
namespace tl {
...
...@@ -4,8 +4,10 @@
#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>
#ifndef __CUDACC_RTC__
#include <type_traits>
#include <utility>
#endif
namespace tl {
...
...@@ -19,6 +19,11 @@
#ifdef __CUDACC_RTC__
// Disable problematic CUDA standard library headers in NVRTC environment
// Vector types (float4, uchar, etc.) are built-in to NVRTC and don't need these
// headers
#define _LIBCUDACXX___TUPLE_VECTOR_TYPES_H // Prevent vector_types.h inclusion
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
...@@ -67,6 +72,24 @@ template <class T> struct is_same<T, T> : true_type {};
template <class T, class U>
inline constexpr bool is_same_v = is_same<T, U>::value;
template <class T> struct is_void : false_type {};
template <> struct is_void<void> : true_type {};
template <> struct is_void<const void> : true_type {};
template <> struct is_void<volatile void> : true_type {};
template <> struct is_void<const volatile void> : true_type {};
template <class T> inline constexpr bool is_void_v = is_void<T>::value;
template <class T> struct is_pointer : false_type {};
template <class T> struct is_pointer<T *> : true_type {};
template <class T> struct is_pointer<T *const> : true_type {};
template <class T> struct is_pointer<T *volatile> : true_type {};
template <class T> struct is_pointer<T *const volatile> : true_type {};
template <class T> inline constexpr bool is_pointer_v = is_pointer<T>::value;
namespace index_sequence_impl {
// Based on https://stackoverflow.com/a/32223343/11717224
...@@ -118,6 +141,36 @@ template <bool B, class T = void> struct enable_if {};
template <class T> struct enable_if<true, T> {
  using type = T;
};
template <class T> struct remove_extent {
  using type = T;
};
template <class T> struct remove_extent<T[]> {
  using type = T;
};
template <class T, size_t N> struct remove_extent<T[N]> {
  using type = T;
};
template <class T> using remove_extent_t = typename remove_extent<T>::type;
template <class T, unsigned I = 0>
struct extent : integral_constant<size_t, 0> {};
template <class T> struct extent<T[], 0> : integral_constant<size_t, 0> {};
template <class T, unsigned I> struct extent<T[], I> : extent<T, I - 1> {};
template <class T, size_t N>
struct extent<T[N], 0> : integral_constant<size_t, N> {};
template <class T, size_t N, unsigned I>
struct extent<T[N], I> : extent<T, I - 1> {};
template <class T, unsigned I = 0>
inline constexpr size_t extent_v = extent<T, I>::value;
} // namespace std
#endif
\ No newline at end of file
#pragma once
#include "common.h"
#ifndef __CUDACC_RTC__
#include <cstdint>
#include <type_traits>
#endif
namespace tl {
...
from tilelang import tvm as tvm
import tilelang.language as T
import tilelang.testing
import tilelang
import torch
from tilelang.utils.tensor import map_torch_type


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    stramp = "&*(XS)"

    @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
    def tilelang_callback_cuda_postproc(code, _):
        code = f"// {stramp}\n" + code
        return code

    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc")
    kernel_source = matmul_kernel.get_kernel_source()
    assert stramp in kernel_source, f"Expected {stramp} in the kernel source"


def test_gemm_f16f16f16_nn():
    run_gemm(
        512,
        1024,
        768,
        False,
        False,
        "float16",
        "float16",
        "float16",
        128,
        256,
        32,
        2,
    )


def matmu_jit_kernel(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_jit_kernel(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = matmu_jit_kernel(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc")

    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)

    A = torch.randn(M, K, dtype=in_dtype).cuda()
    B = torch.randn(K, N, dtype=in_dtype).cuda()
    if trans_A:
        A = A.T
    if trans_B:
        B = B.T

    def ref_program(A, B):
        import torch
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(out_dtype)
        return C

    ref_C = ref_program(A, B)
    C = matmul_kernel(A, B)
    tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_gemm_jit_kernel():
    run_gemm_jit_kernel(
        512,
        1024,
        768,
        False,
        False,
        "float16",
        "float16",
        "float16",
        128,
        256,
        32,
        2,
    )


def run_nvrtc_kernel_do_bench(M,
                              N,
                              K,
                              trans_A,
                              trans_B,
                              in_dtype,
                              out_dtype,
                              dtypeAccum,
                              block_M,
                              block_N,
                              block_K,
                              num_stages=3,
                              num_threads=128):
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    matmul_kernel = tilelang.compile(program, execution_backend="nvrtc")

    profiler = matmul_kernel.get_profiler()

    nvrtc_latency = profiler.do_bench(func=matmul_kernel)
    print(f"NVRTC Latency: {nvrtc_latency} ms")
    assert nvrtc_latency is not None

    tvm_latency = profiler.do_bench()
    print(f"TVM Latency: {tvm_latency} ms")
    assert tvm_latency is not None


def test_nvrtc_kernel_do_bench():
    run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128,
                              256, 32, 2)


def run_nvrtc_kernel_multi_stream(M,
                                  N,
                                  K,
                                  trans_A,
                                  trans_B,
                                  in_dtype,
                                  out_dtype,
                                  dtypeAccum,
                                  block_M,
                                  block_N,
                                  block_K,
                                  num_stages=3,
                                  num_threads=128):
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    matmul_kernel = tilelang.compile(program, execution_backend="nvrtc")

    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)

    tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
    tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()
    if trans_A:
        tensor_a = tensor_a.T
    if trans_B:
        tensor_b = tensor_b.T
    tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()

    num_streams = 4
    for _ in range(num_streams):
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            matmul_kernel(tensor_a, tensor_b, tensor_c)


def test_nvrtc_kernel_multi_stream():
    run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16",
                                  128, 256, 32, 2)


def run_nvrtc_dynamic_shape(M,
                            N,
                            K,
                            trans_A,
                            trans_B,
                            in_dtype,
                            out_dtype,
                            dtypeAccum,
                            block_M,
                            block_N,
                            block_K,
                            num_stages=3,
                            num_threads=128):
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    matmul_kernel = tilelang.compile(program, execution_backend="nvrtc")

    if isinstance(M, T.Var):
        M = 1024
    if isinstance(N, T.Var):
        N = 1024
    if isinstance(K, T.Var):
        K = 768

    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)

    tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
    tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()
    if trans_A:
        tensor_a = tensor_a.T
    if trans_B:
        tensor_b = tensor_b.T
    tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()

    matmul_kernel(tensor_a, tensor_b, tensor_c)

    tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
    tilelang.testing.torch_assert_close(
        tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_nvrtc_dynamic_shape():
    run_nvrtc_dynamic_shape(
        T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
    run_nvrtc_dynamic_shape(
        T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
        256, 32, 2)
    run_nvrtc_dynamic_shape(
        T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16",
        "float16", 128, 256, 32, 2)


def check_hopper():
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(0)
    compute_capability = props.major, props.minor
    return compute_capability == (9, 0)


def convolution_im2col(N,
                       C,
                       H,
                       W,
                       F,
                       K,
                       S,
                       D,
                       P,
                       block_M,
                       block_N,
                       block_K,
                       num_stages,
                       threads,
                       dtype="float16",
                       accum_dtype="float"):
    KH, KW = K, K
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1

    @T.prim_func
    def main(
            data: T.Tensor((N, H, W, C), dtype),
            kernel: T.Tensor((KH, KW, C, F), dtype),
            out: T.Tensor((N, OH, OW, F), dtype),
    ):
        with T.Kernel(
                T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
                threads=threads) as (bx, by):
            data_shared = T.alloc_shared((block_M, block_K), dtype)
            kernel_shared = T.alloc_shared((block_K, block_N), dtype)
            out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            out_shared = T.alloc_shared((block_M, block_N), dtype)

            kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
            out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

            T.annotate_layout({
                out_shared: tilelang.layout.make_swizzled_layout(out_shared),
                data_shared: tilelang.layout.make_swizzled_layout(data_shared),
                kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
            })

            T.clear(out_local)
            for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
                T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
                T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                T.gemm(data_shared, kernel_shared, out_local)
            T.copy(out_local, out_shared)
            T.copy(out_shared, out_flat[by * block_M, bx * block_N])

    return main


def run_nvrtc_im2col_tma_desc(N,
                              C,
                              H,
                              W,
                              F,
                              K,
                              S,
                              D,
                              P,
                              block_M,
                              block_N,
                              block_K,
                              num_stages=3,
                              num_threads=256):
    """Test im2col TMA descriptor functionality in NVRTC backend."""
    program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages,
                                 num_threads)
    conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc")

    a = torch.randn(N, H, W, C).cuda().half()
    b = torch.randn(K, K, C, F).cuda().half()

    out_c = conv_kernel(a, b)

    # Reference implementation using torch.conv2d
    def ref_program(A, B):
        A = A.permute(0, 3, 1, 2)  # N, H, W, C -> N, C, H, W
        B = B.permute(3, 2, 0, 1)  # H, W, C, F -> F, C, H, W
        C = torch.conv2d(A, B, stride=S, padding=P, dilation=D)
        C = C.permute(0, 2, 3, 1)  # N, C, H, W -> N, H, W, C
        return C

    ref_c = ref_program(a, b)
    tilelang.testing.torch_assert_close(
        out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_nvrtc_im2col_tma_desc():
    """Test im2col TMA descriptor with NVRTC backend."""
    if not check_hopper():
        import pytest
        pytest.skip("Test requires Hopper GPU (compute capability 9.0)")

    # Small test case for im2col TMA descriptor
    run_nvrtc_im2col_tma_desc(
        N=4,
        C=64,
        H=32,
        W=32,
        F=64,
        K=3,
        S=1,
        D=1,
        P=1,
        block_M=64,
        block_N=128,
        block_K=32,
        num_stages=3,
        num_threads=256)


def test_nvrtc_l2_persistent_map():
    """Test L2 persistent cache annotation with elementwise add."""
    from tilelang.language import annotate_l2_hit_ratio

    M = 1024
    N = 1024

    @tilelang.jit(out_idx=[-1], execution_backend="nvrtc")
    def elementwise_add_with_l2_cache(
        M,
        N,
        block_size=256,
        dtype="float32",
    ):

        @T.prim_func
        def kernel(
                A: T.Tensor((M, N), dtype),
                B: T.Tensor((M, N), dtype),
                C: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(M * N // block_size, threads=block_size) as bx:
                # Annotate L2 persistent cache for buffer B
                # B will be accessed multiple times and benefit from L2 caching
                annotate_l2_hit_ratio({B: 0.8})

                for i in T.serial(block_size):
                    idx = bx * block_size + i
                    if idx < M * N:
                        row = idx // N
                        col = idx % N
                        C[row, col] = A[row, col] + B[row, col]

        return kernel

    # Compile the kernel
    kernel = elementwise_add_with_l2_cache(M, N)

    # Create test tensors
    a = torch.randn(M, N, dtype=torch.float32).cuda()
    b = torch.randn(M, N, dtype=torch.float32).cuda()

    # Run kernel with out_idx=[-1], C is returned not passed in
    c = kernel(a, b)

    # Verify correctness
    ref_c = a + b
    tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5)
    print("L2 persistent map test passed!")


if __name__ == "__main__":
    tilelang.testing.main()
from __future__ import annotations

import ctypes
import importlib
import logging
import os
import os.path as osp
import subprocess
import tempfile
from typing import Any
...@@ -21,14 +19,6 @@ from .utils import is_cpu_target, is_cuda_target, is_hip_target

logger = logging.getLogger(__name__)

try:
    from tilelang.jit.adapter.nvrtc import is_nvrtc_available
    if is_nvrtc_available:
        import cuda.bindings.driver as cuda
        from tilelang.contrib.nvrtc import compile_cuda
except ImportError:
    is_nvrtc_available = False


class LibraryGenerator:
    srcpath: str | None = None
...@@ -183,95 +173,3 @@ class LibraryGenerator:

    def set_src_path(self, srcpath):
        self.srcpath = srcpath


class PyLibraryGenerator(LibraryGenerator):
    host_func: str | None = None
    culib = None
    pymodule = None

    def __init__(self, target: Target, verbose: bool = False):
        if not is_nvrtc_available:
            raise ImportError("cuda-python is not available, nvrtc backend cannot be used. "
                              "Please install cuda-python via `pip install cuda-python` "
                              "if you want to use the nvrtc backend.")
        super().__init__(target, verbose)

    @staticmethod
    def import_from_file(module_name, file_path):
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def update_host_func(self, host_func: str):
        self.host_func = host_func

    def load_lib(self, lib_path: str | None = None):
        if lib_path is None:
            lib_path = self.libpath
        pypath = lib_path.replace(".cubin", ".py")
        self.pymodule = self.import_from_file("kernel", pypath)

        # Ensure the context is valid
        ctx = cuda.cuCtxGetCurrent()[1]
        if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS:
            import torch
            torch.cuda.synchronize()

        result, self.culib = cuda.cuLibraryLoadFromFile(
            bytes(lib_path, "utf-8"), [], [], 0, [], [], 0)
        assert result == cuda.CUresult.CUDA_SUCCESS, f"Failed to load library: {lib_path}"

    def compile_lib(self, timeout: float = None):
        target = self.target
        verbose = self.verbose
        if is_cuda_target(target):
            from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH)

            src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)  # noqa: SIM115
            libpath = src.name.replace(".cu", ".cubin")

            project_root = osp.join(osp.dirname(__file__), "..", "..")

            if CUTLASS_INCLUDE_DIR is None:
                cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include"))
            else:
                cutlass_path = CUTLASS_INCLUDE_DIR
            if TILELANG_TEMPLATE_PATH is None:
                tl_template_path = osp.abspath(osp.join(project_root, "src"))
            else:
                tl_template_path = TILELANG_TEMPLATE_PATH

            cuda_home = CUDA_HOME if CUDA_HOME else "/usr/local/cuda"
            options = [f"-I{tl_template_path}", f"-I{cutlass_path}", f"-I{cuda_home}/include"]
            if self.compile_flags:
                options += [
                    item for flag in self.compile_flags for item in flag.split()
                    if item not in options
                ]

            cubin_bytes = compile_cuda(
                self.lib_code, target_format="cubin", options=options, verbose=verbose)
            with open(libpath, "wb") as f:
                f.write(cubin_bytes)

            src.write(self.lib_code)
            src.flush()
            self.srcpath = src.name
            self.libpath = libpath

            pypath = src.name.replace(".cu", ".py")
            with open(pypath, "w") as f:
                f.write(self.host_func)
        else:
            raise ValueError(f"Unsupported target: {target}")

    def __del__(self):
        if self.culib:
            result = cuda.cuLibraryUnload(self.culib)[0]
            if result != cuda.CUresult.CUDA_SUCCESS:
                logger.warning(f"Failed to unload library: {self.libpath}")
            self.culib = None
...@@ -5,7 +5,10 @@ This module provides runtime compilation support using NVIDIA's NVRTC API.
import logging

__all__ = ['NVRTCKernelAdapter', 'is_nvrtc_available', 'check_nvrtc_available']
__all__ = [
    'NVRTCKernelAdapter', 'TLNVRTCSourceWrapper', 'NVRTCLibraryGenerator', 'is_nvrtc_available',
    'check_nvrtc_available'
]

logger = logging.getLogger(__name__)
...@@ -37,7 +40,9 @@ def check_nvrtc_available():
# Conditionally import the adapter
if is_nvrtc_available:
    from .adapter import NVRTCKernelAdapter  # noqa: F401
    from .adapter import NVRTCKernelAdapter
    from .wrapper import TLNVRTCSourceWrapper
    from .libgen import NVRTCLibraryGenerator
else:
    # Provide a dummy class that raises error on instantiation
    class NVRTCKernelAdapter:
...@@ -45,3 +50,19 @@ else:

        def __init__(self, *args, **kwargs):
            raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)

        @classmethod
        def from_database(cls, *args, **kwargs):
            raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)

    class TLNVRTCSourceWrapper:
        """Dummy TLNVRTCSourceWrapper that raises ImportError on instantiation."""

        def __init__(self, *args, **kwargs):
            raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)

    class NVRTCLibraryGenerator:
        """Dummy NVRTCLibraryGenerator that raises ImportError on instantiation."""

        def __init__(self, *args, **kwargs):
            raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)
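For context, a minimal sketch (not part of this diff) of how calling code can guard on the availability flag exported here; the fallback message is illustrative only, and the dummy classes above are what get imported when cuda-python is missing.

```python
# Illustrative only: guard on the flag instead of instantiating the dummy adapter,
# which would raise ImportError with NVRTC_UNAVAILABLE_MESSAGE.
from tilelang.jit.adapter.nvrtc import is_nvrtc_available, NVRTCKernelAdapter

if not is_nvrtc_available:
    print("NVRTC backend unavailable; pick another execution_backend (e.g. cython)")
```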
...@@ -9,12 +9,13 @@ from tvm.target import Target
from tilelang import tvm as tvm
from tilelang.engine.param import KernelParam
from tilelang.jit.adapter.wrapper import TLPyWrapper
from tilelang.jit.adapter.libgen import PyLibraryGenerator
from tilelang.utils.language import retrieve_func_from_module
from tilelang.utils.target import determine_target
from tilelang.jit.adapter.base import BaseKernelAdapter
from tilelang.jit.adapter.nvrtc import is_nvrtc_available, check_nvrtc_available
from .libgen import NVRTCLibraryGenerator

logger = logging.getLogger(__name__)

# Import cuda bindings if available
...@@ -75,7 +76,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
        self.wrapper.assign_device_module(device_mod)
        self.host_func, self.function_names = self.wrapper.wrap(kernel_global_source)

        self.lib_generator = PyLibraryGenerator(self.target, self.verbose)
        self.lib_generator = NVRTCLibraryGenerator(self.target, self.verbose)
        self.lib_generator.update_lib_code(self.kernel_global_source)
        self.lib_generator.update_host_func(self.host_func)
        self.lib_generator.assign_compile_flags(compile_flags)
...@@ -130,7 +131,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
        adapter.target = Target.canon_target(determine_target(target))
        adapter.verbose = verbose

        adapter.lib_generator = PyLibraryGenerator(adapter.target, adapter.verbose)
        adapter.lib_generator = NVRTCLibraryGenerator(adapter.target, adapter.verbose)
        adapter.lib_generator.assign_compile_flags(compile_flags)
        adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
        adapter.pymodule = adapter.lib_generator.pymodule
...
"""NVRTC Library Generator for TileLang.
Compiles CUDA kernels at runtime using NVRTC and manages resulting binaries.
Why NVRTC instead of nvcc:
- No offline compilation step, enables true JIT workflows
- Works without CUDA toolkit installed (only requires driver)
- Allows kernel specialization based on runtime parameters
Key responsibilities:
- Compile CUDA source to cubin using NVRTC API
- Generate accompanying Python launcher code
- Load compiled cubin and extract kernel handles
- Manage library lifecycle (load/unload)
"""
from __future__ import annotations
import importlib
import logging
import os.path as osp
import platform
import tempfile
from types import ModuleType
from tvm.target import Target
from tilelang import tvm as tvm
from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.jit.adapter.utils import is_cuda_target
from tilelang.jit.adapter.nvrtc import is_nvrtc_available, NVRTC_UNAVAILABLE_MESSAGE
logger = logging.getLogger(__name__)
if is_nvrtc_available:
import cuda.bindings.driver as cuda
from tilelang.contrib.nvrtc import compile_cuda
else:
raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)
class NVRTCLibraryGenerator(LibraryGenerator):
"""Runtime compiler and loader for NVRTC-compiled CUDA kernels.
Lifecycle:
1. compile_lib(): CUDA source → cubin + Python launcher
2. load_lib(): cubin → loaded library + kernel handles
3. pymodule.call(): Execute kernels via Python launcher
4. __del__: Cleanup (unload library)
Why three files (cu, cubin, py):
- .cu: Source for debugging, kept in temp directory
- .cubin: Compiled binary, loaded by CUDA driver
- .py: Launch code, imported as Python module
Attributes:
host_func: Generated Python launch code (from wrapper)
culib: CUDA library handle (CUlibrary)
pymodule: Imported Python module containing call() function
"""
host_func: str | None = None
culib: cuda.CUlibrary | None = None
pymodule: ModuleType | None = None
pypath: str | None = None
def __init__(self, target: Target, verbose: bool = False):
"""Initialize NVRTC library generator.
Args:
target: Compilation target (must be CUDA)
verbose: Enable verbose compilation output
"""
super().__init__(target, verbose)
@staticmethod
def import_from_file(module_name, file_path):
"""Dynamically import Python module from file path.
Standard importlib pattern for loading modules outside sys.path.
Used to import generated .py launcher code from temp directory.
Args:
module_name: Name to assign to imported module
file_path: Absolute path to .py file
Returns:
Imported module object
"""
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec is None or spec.loader is None:
raise ImportError(f"Failed to import module from file: {file_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def update_host_func(self, host_func: str):
"""Store generated Python launch code for later file write.
Called by adapter after wrapper generates the launch code.
This is the bridge between code generation and file output.
Args:
host_func: Python source code containing call() function
"""
self.host_func = host_func
def load_lib(self, lib_path: str | None = None):
"""Load compiled cubin and Python launcher into memory.
Why two loads:
1. Import Python module for launch logic
2. Load cubin via CUDA Driver API for kernel handles
Context synchronization: CUDA context must be current before loading.
If not, use torch.cuda.synchronize() to establish context.
Args:
lib_path: Path to .cubin file (optional, uses self.libpath if None)
Side effects:
- Sets self.pymodule to imported Python module
- Sets self.culib to CUDA library handle
"""
if lib_path is None:
lib_path = self.libpath
else:
self.libpath = lib_path
self.pypath = lib_path.replace(".cubin", ".py")
self.pymodule = self.import_from_file("kernel", self.pypath)
# Ensure the context is valid
ctx = cuda.cuCtxGetCurrent()[1]
if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS:
import torch
torch.cuda.synchronize()
result, self.culib = cuda.cuLibraryLoadFromFile(
bytes(lib_path, "utf-8"), [], [], 0, [], [], 0)
if result != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to load library: {lib_path}, error: {result}")
def compile_lib(self, timeout: float | None = None):
"""Compile CUDA source to cubin using NVRTC and write output files.
Output artifacts (all in temp directory):
- .cu: Source code (for debugging)
- .cubin: Compiled binary (for execution)
- .py: Python launcher (for calling kernels)
Include paths setup:
- TileLang templates: kernel primitives and utilities
- CUTLASS: optimized GEMM/tensor ops
- CUDA headers: driver/runtime APIs
Why architecture detection:
ARM64 servers (SBSA) have different header paths than x86_64.
Args:
timeout: Compilation timeout in seconds (currently unsupported by NVRTC compiler)
Side effects:
- Writes .cu, .cubin, .py files to temp directory
- Sets self.srcpath, self.libpath, self.pypath
"""
target = self.target
verbose = self.verbose
if is_cuda_target(target):
from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH)
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
libpath = src.name.replace(".cu", ".cubin")
project_root = osp.join(osp.dirname(__file__), "..", "..")
if CUTLASS_INCLUDE_DIR is None:
cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include"))
else:
cutlass_path = CUTLASS_INCLUDE_DIR
if TILELANG_TEMPLATE_PATH is None:
tl_template_path = osp.abspath(osp.join(project_root, "src"))
else:
tl_template_path = TILELANG_TEMPLATE_PATH
cuda_home = CUDA_HOME if CUDA_HOME else "/usr/local/cuda"
__CUDACC_VER_MAJOR__ = cuda.CUDA_VERSION // 1000
# Determine target architecture
machine = platform.machine()
target_arch = "sbsa-linux" if machine in ("aarch64", "arm64") else "x86_64-linux"
options = [
f"-I{tl_template_path}",
f"-I{cutlass_path}",
f"-I{cuda_home}/include",
f"-I{cuda_home}/targets/{target_arch}/include",
f"-I{cuda_home}/targets/{target_arch}/include/cccl",
f"-D__CUDACC_VER_MAJOR__={__CUDACC_VER_MAJOR__}",
]
if self.compile_flags:
options += [
item for flag in self.compile_flags for item in flag.split()
if item not in options
]
cubin_bytes = compile_cuda(
self.lib_code, target_format="cubin", options=options, verbose=verbose)
with open(libpath, "wb") as f:
f.write(cubin_bytes)
src.write(self.lib_code)
src.flush()
self.srcpath = src.name
self.libpath = libpath
self.pypath = src.name.replace(".cu", ".py")
if self.host_func is None:
raise RuntimeError(
"Host function is not set, please call update_host_func() first.")
with open(self.pypath, "w") as f:
f.write(self.host_func)
else:
raise ValueError(f"Unsupported target: {target}")
def __del__(self):
"""Cleanup: unload CUDA library when object is destroyed.
Critical for resource management - CUDA libraries consume GPU memory.
Failure to unload is logged but not raised (destructor can't fail).
Why explicit unload:
Python GC doesn't know about GPU resources, must release manually.
"""
if self.culib:
result = cuda.cuLibraryUnload(self.culib)[0]
if result != cuda.CUresult.CUDA_SUCCESS:
logger.warning(f"Failed to unload library: {self.libpath}")
self.culib = None
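A rough usage sketch (not part of this diff) of the compile-then-load lifecycle described in the docstrings above. The method names follow this file and the adapter code shown earlier; `kernel_source` and `host_source` are hypothetical stand-ins for the strings the NVRTC wrapper/codegen produces.

```python
# Hypothetical sketch of the compile -> load lifecycle; kernel_source and
# host_source stand in for the generated CUDA source and Python launcher text.
from tvm.target import Target
from tilelang.jit.adapter.nvrtc import NVRTCLibraryGenerator


def build_and_load(kernel_source: str, host_source: str):
    gen = NVRTCLibraryGenerator(Target("cuda"), verbose=False)
    gen.update_lib_code(kernel_source)  # CUDA source compiled by NVRTC
    gen.update_host_func(host_source)   # generated Python launcher text
    gen.compile_lib()                   # writes .cu / .cubin / .py to a temp dir
    gen.load_lib()                      # imports the launcher, loads the cubin
    return gen.pymodule                 # module exposing the generated entry point
```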
This diff is collapsed.
from __future__ import annotations

import re
from typing import Literal
from typing import Literal, Callable, Any

from tilelang import tvm as tvm
from tvm import IRModule, tir
from tvm.target import Target
...@@ -107,13 +107,16 @@ def get_annotated_mod(
    return dispatch[model_type](mod)


def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None) -> str:
def pythonic_expr(expr: tvm.tir.PrimExpr,
                  dtype_map: dict[str, str] | None = None,
                  ignore_cast: bool = False) -> str:
    """
    Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.

    Args:
        expr: The TVM PrimExpr to convert.
        dtype_map: A dictionary mapping data types to their string representations.
        ignore_cast: Whether to ignore the cast operator and return the string representation of the value without the cast.

    Returns:
        A string representation of the expression.
    """
...@@ -158,10 +161,11 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = Non
        elif isinstance(node, tvm.tir.Cast):
            # C-style cast has high precedence
            value_str, _ = node_to_result_map[node.value]
            if dtype_map is None:
                s = f"({node.dtype}){value_str}"
            else:
                s = f"({dtype_map[node.dtype]}){value_str}"
            if ignore_cast:
                s = value_str
            else:
                type_str = node.dtype if dtype_map is None else dtype_map[node.dtype]
                s = f"({type_str}){value_str}"
            p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE)
        elif isinstance(
                node,
...@@ -216,3 +220,238 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = Non
    tvm.tir.stmt_functor.post_order_visit(expr, _visitor)
    return next(iter(node_to_result_map[expr]), "")
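To make the new `ignore_cast` flag concrete, a small hedged example (not part of this diff): the import path assumes this is the adapter `utils` module, and the printed strings are approximate since exact parenthesization depends on the precedence handling above.

```python
# Illustrative only: shows the intent of ignore_cast on a Cast expression.
from tvm import tir
from tilelang.jit.adapter.utils import pythonic_expr

m = tir.Var("m", "int32")
expr = tir.Cast("int64", m + 1)
print(pythonic_expr(expr))                    # e.g. "(int64)(m + 1)"
print(pythonic_expr(expr, ignore_cast=True))  # e.g. "m + 1" (cast dropped)
```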
def maybe_desc_name(name: str,
                    matches: list[str],
                    i: int,
                    desc_name_map: dict[str, str] | None = None) -> bool:
    """
    Check if a parameter name corresponds to a TMA descriptor.

    Args:
        name: The parameter name to check.
        matches: List of all matched parameter names.
        i: Index of the current match.
        desc_name_map: Optional mapping to store descriptor name relationships.

    Returns:
        True if the parameter is a TMA descriptor.
    """
    match = matches[i]
    if not (match == name + "_desc" or match.startswith(name + "_desc_")):
        return False
    desc_decls = []
    if desc_name_map is not None:
        desc_name_map[match] = name
    if i > 0:
        desc_decls.append(matches[i - 1])
    if i < len(matches) - 1:
        desc_decls.append(matches[i + 1])
    return any([decl == "CUtensorMap" for decl in desc_decls])


def parse_function_call_args(
    declaration: str,
    function_args: list[dict[str, str]],
    function_params: list[Any],
    desc_name_map: dict[str, str] | None = None,
    desc_name_var_map: dict[str, tvm.tir.Var] | None = None,
    transform_arg: Callable[[str, str], Any] | None = None,
) -> list[Any]:
    """
    Parse function call arguments from a kernel declaration.

    Args:
        declaration: The kernel function declaration string.
        function_args: List of function argument specifications.
        function_params: List of function parameters from TVM IR.
        desc_name_map: Optional mapping for descriptor names.
        desc_name_var_map: Optional mapping from descriptor names to TVM variables.
        transform_arg: Optional function to transform each argument (name, type) -> result.

    Returns:
        List of parsed call arguments.
    """
    pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
    matches = re.findall(pattern, declaration)
    call_args = []
    for i, match in enumerate(matches):
        for arg in function_args:
            if arg["name"] == match:
                if transform_arg is not None:
                    call_args.append(transform_arg(match, arg["type"]))
                else:
                    call_args.append(match)
            elif maybe_desc_name(arg["name"], matches, i, desc_name_map):
                if transform_arg is not None:
                    call_args.append(transform_arg(match, "None"))
                else:
                    call_args.append(match)
                if desc_name_var_map is not None and function_params is not None:
                    assert len(call_args) <= len(function_params), \
                        f"Too many arguments: {len(call_args)} > {len(function_params)}"
                    desc_name_var_map[match] = function_params[len(call_args) - 1]
    return call_args
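As a sanity check on the declaration regex above, a small hedged example (not part of this diff) with a made-up kernel signature; real declarations come from the generated CUDA source, and the parameter names here are illustrative only.

```python
# Illustrative only: what the declaration regex extracts. "CUtensorMap" surfaces
# as its own token next to "A_desc", which is exactly what maybe_desc_name()
# checks to recognize A_desc as the TMA descriptor belonging to parameter A.
import re

declaration = "half_t* __restrict__ A, half_t* __restrict__ B, CUtensorMap A_desc"
pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
print(re.findall(pattern, declaration))
# ['A', 'B', 'CUtensorMap', 'A_desc']
```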
class TMADescriptorParams:
    """Parsed TMA descriptor parameters."""

    def __init__(self,
                 handle_name: str,
                 dtype: str,
                 tensor_rank: int,
                 global_address: Any,
                 is_img2col: bool = False):
        self.handle_name = handle_name
        self.dtype = dtype
        self.tensor_rank = tensor_rank
        self.global_address = global_address
        self.is_img2col = is_img2col
        # Common fields
        self.global_dim: list[str] = []
        self.global_stride: list[str] = []
        self.element_strides: list[str] = []
        self.interleave: str = ""
        self.swizzle: str = ""
        self.l2_promotion: str = ""
        self.oob_fill: str = ""
        # Tiled-specific fields
        self.box_dim: list[str] = []
        # Im2col-specific fields
        self.lower_corner: list[str] = []
        self.upper_corner: list[str] = []
        self.smem_box_channel: str = ""
        self.smem_box_pixel: str = ""


def parse_tma_descriptor_args(
    tma_descriptor_args: dict[tvm.tir.Var, list[Any]],
    desc_name_map: dict[str, str],
    desc_name_var_map: dict[str, tvm.tir.Var],
    pythonic_expr_func: Callable[[Any], str],
) -> list[TMADescriptorParams]:
    """
    Parse TMA descriptor arguments into structured parameters.

    Args:
        tma_descriptor_args: Dictionary mapping TMA descriptor variables to their arguments.
        desc_name_map: Mapping from descriptor handles to parameter names.
        desc_name_var_map: Mapping from descriptor handles to TVM variables.
        pythonic_expr_func: Function to convert TVM expressions to strings.

    Returns:
        List of parsed TMA descriptor parameters.
    """
    if not tma_descriptor_args:
        return []

    results = []
    for handle_name, _ in desc_name_map.items():
        assert handle_name in desc_name_var_map, \
            f"Handle name {handle_name} not found in desc_name_var_map"
        desc_var = desc_name_var_map[handle_name]
        assert desc_var in tma_descriptor_args, \
            f"TMA descriptor {desc_var} not found in {tma_descriptor_args}"
        args = tma_descriptor_args[desc_var]

        # Skip __tvm_tensormap_create_tiled and second element (like CUDA version)
        if len(args) < 3:
            raise ValueError(
                f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
        tma_create_str, _, dtype, tensor_rank, global_address, *remaining_args = args
        is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")

        # Convert basic fields
        dtype = pythonic_expr_func(dtype)
        tensor_rank = int(pythonic_expr_func(tensor_rank))

        # Validate tensor_rank
        if not isinstance(tensor_rank, int) or tensor_rank <= 0:
            raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer")

        params = TMADescriptorParams(handle_name, dtype, tensor_rank, global_address, is_img2col)

        if not is_img2col:
            # Tiled mode
            expected_args_len = 4 * tensor_rank + 4
            if len(remaining_args) < expected_args_len:
                raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
                                 f"expected {expected_args_len} for tensor_rank {tensor_rank}")

            # Extract dimensions and strides
            params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]]
            params.global_stride = [
                pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank]
            ]
            params.box_dim = [
                pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank]
            ]
            params.element_strides = [
                pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank]
            ]

            # Extract remaining parameters
            try:
                interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank:4 *
                                                                             tensor_rank + 4]
                params.interleave = pythonic_expr_func(interleave)
                params.swizzle = pythonic_expr_func(swizzle)
                params.l2_promotion = pythonic_expr_func(l2_promotion)
                params.oob_fill = pythonic_expr_func(oob_fill)
            except ValueError as e:
                raise ValueError(
                    "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
                ) from e
        else:
            # Im2col mode
            expected_args_len = 5 * tensor_rank + 2
            if len(remaining_args) < expected_args_len:
                raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
                                 f"expected {expected_args_len} for tensor_rank {tensor_rank}")

            # Extract dimensions and strides
            params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]]
            params.global_stride = [
                pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank]
            ]
            params.element_strides = [
                pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank]
            ]
            params.lower_corner = [
                pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
            ]
            params.upper_corner = [
                pythonic_expr_func(i)
                for i in remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
            ]

            # Extract remaining parameters
            try:
                smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = \
                    remaining_args[5 * tensor_rank - 4:5 * tensor_rank + 2]
                params.smem_box_pixel = pythonic_expr_func(smem_box_pixel)
                params.smem_box_channel = pythonic_expr_func(smem_box_channel)
                params.interleave = pythonic_expr_func(interleave)
                params.swizzle = pythonic_expr_func(swizzle)
                params.l2_promotion = pythonic_expr_func(l2_promotion)
                params.oob_fill = pythonic_expr_func(oob_fill)
            except ValueError as e:
                raise ValueError(
                    "Failed to unpack the final 6 TMA parameters "
                    "(smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)"
                ) from e

        results.append(params)

    return results
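The tiled-mode length check above (`4 * tensor_rank + 4`) reflects how the descriptor arguments are laid out after the five fixed leading fields; a small hedged sketch (not part of this diff) of the slicing for a hypothetical rank-2 descriptor, using placeholder values:

```python
# Illustrative layout of remaining_args for a tiled (non-im2col) descriptor with
# tensor_rank = 2, matching the slices taken in parse_tma_descriptor_args.
tensor_rank = 2
remaining_args = list(range(4 * tensor_rank + 4))  # 12 placeholder entries

global_dim = remaining_args[:tensor_rank]                           # [0, 1]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]         # [2, 3]
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]           # [4, 5]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]   # [6, 7]
interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank:]  # [8..11]
```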
This diff is collapsed.
...@@ -15,7 +15,7 @@ from tilelang import tvm
from tilelang import env
from tilelang.engine.param import CompiledArtifact, KernelParam
from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
                                  NVRTCKernelAdapter, TorchDLPackKernelAdapter, MetalKernelAdapter)
from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
                                  TorchDLPackKernelAdapter, MetalKernelAdapter)
from tilelang.profiler import Profiler, TensorSupplyType
from tilelang.utils.target import determine_target
from tilelang.contrib import nvcc as tl_nvcc
...@@ -270,6 +270,7 @@ class JITKernel(Generic[_P, _T]):
                compile_flags=compile_flags,
            )
        elif execution_backend == "nvrtc":
            from tilelang.jit.adapter import NVRTCKernelAdapter
            adapter = NVRTCKernelAdapter(
                params=artifact.params,
                result_idx=out_idx,
...@@ -339,6 +340,7 @@ class JITKernel(Generic[_P, _T]):
                pass_configs=pass_configs,
            )
        elif execution_backend == "nvrtc":
            from tilelang.jit.adapter import NVRTCKernelAdapter
            adapter = NVRTCKernelAdapter.from_database(
                params=params,
                result_idx=result_idx,
...
...@@ -5,6 +5,7 @@ from typing import Callable
from tilelang.layout import Layout
from tvm.script.parser.tir import attr, block_attr
from tvm.tir import FloatImm

__all__ = [
    "use_swizzle",
...@@ -49,5 +50,5 @@ def annotate_l2_hit_ratio(l2_hit_ratio_map: dict):
    _l2_hit_ratio_map = {}
    for buffer, hit_ratio in l2_hit_ratio_map.items():
        assert buffer.scope() == "global", "persistent L2 can only be applied to global buffers"
        _l2_hit_ratio_map[buffer.data] = float(hit_ratio)
        _l2_hit_ratio_map[buffer.data] = FloatImm("float32", float(hit_ratio))
    return block_attr({"l2_hit_ratio_map": _l2_hit_ratio_map})
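For reference, a minimal sketch (not part of this diff) of what the annotation map now stores, assuming TVM's `FloatImm(dtype, value)` constructor as imported above; the motivation given here is an assumption, not stated in the commit.

```python
from tvm.tir import FloatImm

# Illustrative only: the hit ratio is now a TVM FloatImm node rather than a raw
# Python float, presumably so the block_attr annotation map holds TVM objects.
hit_ratio = FloatImm("float32", 0.8)
assert hit_ratio.dtype == "float32"
```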