"test/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "875a0da84c4a7e02e49bc863665e486baf520510"
Unverified commit eb415744, authored by Gabriel Wu and committed by GitHub

[fix] NVRTC execution backend (#1256)

* [fix] NVRTC execution backend

* [fmt] run pre-commit

* [fix] coderabbit reviews

* [test] add cuda-python to test dep

* [fix] coderabbit reviews

* [fix] CUDA 13 compatibility

* [fix] sm90

* [fix] CUDA 13 compatibility

* [fix] pre-commit

* [fix] always use cuda::std::__atomic_ref_impl

* [fix] restore to external API

* Revert "[fix] restore to external API"

This reverts commit 49bd875638fb631d270015f408991d38fd1e9a5d.

* [fmt] use spaces instead of tabs for py codegen

* [fix] im2col API

* [fix] revert atomic.h

* [fix] dynamic shape

* [refactor] extract common utils

* [feat] support L2 persistent map

* [fix] l2 persistent map

* [fix] pre-commit

* [fix] restore _TYPE_MAP

* [fix] pre-commit

* [fix] avoid duplicate TMA descs

* [docs] add docstring

* [fix] coderabbit

* [fix] coderabbit

* [fix] coderabbit

* [fix] coderabbit
parent 0af3fd7c
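For orientation, the backend added by this change is selected with execution_backend="nvrtc" at compile time. A minimal sketch, distilled from the tests further down in this diff (the program, inputs, and variable names are placeholders, not part of the change itself):

import tilelang

# "program" stands for any T.prim_func, e.g. the matmul built in the tests below.
kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc")
source = kernel.get_kernel_source()  # NVRTC-compiled CUDA source can be inspected
result = kernel(a, b)                # dispatch goes through the generated Python launcher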
@@ -6,3 +6,4 @@
# CUDA specific requirements
flash-attn==2.5.8
cuda-python==12.9.4
@@ -4,8 +4,10 @@
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
#ifndef __CUDACC_RTC__
#include <type_traits>
#include <utility>
#endif
namespace tl {
...
@@ -2,8 +2,10 @@
#include "../common.h"
#ifndef __CUDACC_RTC__
#include <type_traits>
#include <utility>
#endif
namespace tl {
...
@@ -4,8 +4,10 @@
#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>
#ifndef __CUDACC_RTC__
#include <type_traits>
#include <utility>
#endif
namespace tl {
...
@@ -19,6 +19,11 @@
#ifdef __CUDACC_RTC__
// Disable problematic CUDA standard library headers in NVRTC environment
// Vector types (float4, uchar, etc.) are built-in to NVRTC and don't need these
// headers
#define _LIBCUDACXX___TUPLE_VECTOR_TYPES_H // Prevent vector_types.h inclusion
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
@@ -67,6 +72,24 @@ template <class T> struct is_same<T, T> : true_type {};
template <class T, class U>
inline constexpr bool is_same_v = is_same<T, U>::value;
template <class T> struct is_void : false_type {};
template <> struct is_void<void> : true_type {};
template <> struct is_void<const void> : true_type {};
template <> struct is_void<volatile void> : true_type {};
template <> struct is_void<const volatile void> : true_type {};
template <class T> inline constexpr bool is_void_v = is_void<T>::value;
template <class T> struct is_pointer : false_type {};
template <class T> struct is_pointer<T *> : true_type {};
template <class T> struct is_pointer<T *const> : true_type {};
template <class T> struct is_pointer<T *volatile> : true_type {};
template <class T> struct is_pointer<T *const volatile> : true_type {};
template <class T> inline constexpr bool is_pointer_v = is_pointer<T>::value;
namespace index_sequence_impl {
// Based on https://stackoverflow.com/a/32223343/11717224
@@ -118,6 +141,36 @@ template <bool B, class T = void> struct enable_if {};
template <class T> struct enable_if<true, T> {
using type = T;
};
template <class T> struct remove_extent {
using type = T;
};
template <class T> struct remove_extent<T[]> {
using type = T;
};
template <class T, size_t N> struct remove_extent<T[N]> {
using type = T;
};
template <class T> using remove_extent_t = typename remove_extent<T>::type;
template <class T, unsigned I = 0>
struct extent : integral_constant<size_t, 0> {};
template <class T> struct extent<T[], 0> : integral_constant<size_t, 0> {};
template <class T, unsigned I> struct extent<T[], I> : extent<T, I - 1> {};
template <class T, size_t N>
struct extent<T[N], 0> : integral_constant<size_t, N> {};
template <class T, size_t N, unsigned I>
struct extent<T[N], I> : extent<T, I - 1> {};
template <class T, unsigned I = 0>
inline constexpr size_t extent_v = extent<T, I>::value;
} // namespace std
#endif
\ No newline at end of file
#pragma once
#include "common.h"
#ifndef __CUDACC_RTC__
#include <cstdint>
#include <type_traits>
#endif
namespace tl {
...
from tilelang import tvm as tvm
import tilelang.language as T
import tilelang.testing
import tilelang
import torch
from tilelang.utils.tensor import map_torch_type
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
stramp = "&*(XS)"
@tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code
return code
matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc")
kernel_source = matmul_kernel.get_kernel_source()
assert stramp in kernel_source, f"Expected {stramp} in the kernel source"
def test_gemm_f16f16f16_nn():
run_gemm(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def matmul_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_jit_kernel(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc")
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
A = torch.randn(M, K, dtype=in_dtype).cuda()
B = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
A = A.T
if trans_B:
B = B.T
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(out_dtype)
return C
ref_C = ref_program(A, B)
C = matmul_kernel(A, B)
tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_gemm_jit_kernel():
run_gemm_jit_kernel(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def run_nvrtc_kernel_do_bench(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, execution_backend="nvrtc")
profiler = matmul_kernel.get_profiler()
nvrtc_latency = profiler.do_bench(func=matmul_kernel)
print(f"NVRTC Latency: {nvrtc_latency} ms")
assert nvrtc_latency is not None
tvm_latency = profiler.do_bench()
print(f"TVM Latency: {tvm_latency} ms")
assert tvm_latency is not None
def test_nvrtc_kernel_do_bench():
run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
def run_nvrtc_kernel_multi_stream(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, execution_backend="nvrtc")
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()
num_streams = 4
for _ in range(num_streams):
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
matmul_kernel(tensor_a, tensor_b, tensor_c)
def test_nvrtc_kernel_multi_stream():
run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16",
128, 256, 32, 2)
def run_nvrtc_dynamic_shape(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, execution_backend="nvrtc")
if isinstance(M, T.Var):
M = 1024
if isinstance(N, T.Var):
N = 1024
if isinstance(K, T.Var):
K = 768
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()
matmul_kernel(tensor_a, tensor_b, tensor_c)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close(
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_nvrtc_dynamic_shape():
run_nvrtc_dynamic_shape(
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_nvrtc_dynamic_shape(
T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
run_nvrtc_dynamic_shape(
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16",
"float16", 128, 256, 32, 2)
def check_hopper():
if not torch.cuda.is_available():
return False
props = torch.cuda.get_device_properties(0)
compute_capability = props.major, props.minor
return compute_capability == (9, 0)
def convolution_im2col(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
@T.prim_func
def main(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
out_shared = T.alloc_shared((block_M, block_N), dtype)
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
data_shared: tilelang.layout.make_swizzled_layout(data_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
})
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
T.copy(out_local, out_shared)
T.copy(out_shared, out_flat[by * block_M, bx * block_N])
return main
def run_nvrtc_im2col_tma_desc(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=256):
"""Test im2col TMA descriptor functionality in NVRTC backend."""
program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages,
num_threads)
conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc")
a = torch.randn(N, H, W, C).cuda().half()
b = torch.randn(K, K, C, F).cuda().half()
out_c = conv_kernel(a, b)
# Reference implementation using torch.conv2d
def ref_program(A, B):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
C = torch.conv2d(A, B, stride=S, padding=P, dilation=D)
C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C
return C
ref_c = ref_program(a, b)
tilelang.testing.torch_assert_close(
out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_nvrtc_im2col_tma_desc():
"""Test im2col TMA descriptor with NVRTC backend."""
if not check_hopper():
import pytest
pytest.skip("Test requires Hopper GPU (compute capability 9.0)")
# Small test case for im2col TMA descriptor
run_nvrtc_im2col_tma_desc(
N=4,
C=64,
H=32,
W=32,
F=64,
K=3,
S=1,
D=1,
P=1,
block_M=64,
block_N=128,
block_K=32,
num_stages=3,
num_threads=256)
def test_nvrtc_l2_persistent_map():
"""Test L2 persistent cache annotation with elementwise add."""
from tilelang.language import annotate_l2_hit_ratio
M = 1024
N = 1024
@tilelang.jit(out_idx=[-1], execution_backend="nvrtc")
def elementwise_add_with_l2_cache(
M,
N,
block_size=256,
dtype="float32",
):
@T.prim_func
def kernel(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(M * N // block_size, threads=block_size) as bx:
# Annotate L2 persistent cache for buffer B
# B will be accessed multiple times and benefit from L2 caching
annotate_l2_hit_ratio({B: 0.8})
for i in T.serial(block_size):
idx = bx * block_size + i
if idx < M * N:
row = idx // N
col = idx % N
C[row, col] = A[row, col] + B[row, col]
return kernel
# Compile the kernel
kernel = elementwise_add_with_l2_cache(M, N)
# Create test tensors
a = torch.randn(M, N, dtype=torch.float32).cuda()
b = torch.randn(M, N, dtype=torch.float32).cuda()
# Run kernel with out_idx=[-1], C is returned not passed in
c = kernel(a, b)
# Verify correctness
ref_c = a + b
tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5)
print("L2 persistent map test passed!")
if __name__ == "__main__":
tilelang.testing.main()
from __future__ import annotations
import ctypes
import importlib
import logging
import os
import os.path as osp
import subprocess
import tempfile
from typing import Any
@@ -21,14 +19,6 @@ from .utils import is_cpu_target, is_cuda_target, is_hip_target
logger = logging.getLogger(__name__)
try:
from tilelang.jit.adapter.nvrtc import is_nvrtc_available
if is_nvrtc_available:
import cuda.bindings.driver as cuda
from tilelang.contrib.nvrtc import compile_cuda
except ImportError:
is_nvrtc_available = False
class LibraryGenerator:
srcpath: str | None = None
@@ -183,95 +173,3 @@ class LibraryGenerator:
def set_src_path(self, srcpath):
self.srcpath = srcpath
class PyLibraryGenerator(LibraryGenerator):
host_func: str | None = None
culib = None
pymodule = None
def __init__(self, target: Target, verbose: bool = False):
if not is_nvrtc_available:
raise ImportError("cuda-python is not available, nvrtc backend cannot be used. "
"Please install cuda-python via `pip install cuda-python` "
"if you want to use the nvrtc backend.")
super().__init__(target, verbose)
@staticmethod
def import_from_file(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def update_host_func(self, host_func: str):
self.host_func = host_func
def load_lib(self, lib_path: str | None = None):
if lib_path is None:
lib_path = self.libpath
pypath = lib_path.replace(".cubin", ".py")
self.pymodule = self.import_from_file("kernel", pypath)
# Ensure the context is valid
ctx = cuda.cuCtxGetCurrent()[1]
if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS:
import torch
torch.cuda.synchronize()
result, self.culib = cuda.cuLibraryLoadFromFile(
bytes(lib_path, "utf-8"), [], [], 0, [], [], 0)
assert result == cuda.CUresult.CUDA_SUCCESS, f"Failed to load library: {lib_path}"
def compile_lib(self, timeout: float = None):
target = self.target
verbose = self.verbose
if is_cuda_target(target):
from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH)
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) # noqa: SIM115
libpath = src.name.replace(".cu", ".cubin")
project_root = osp.join(osp.dirname(__file__), "..", "..")
if CUTLASS_INCLUDE_DIR is None:
cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include"))
else:
cutlass_path = CUTLASS_INCLUDE_DIR
if TILELANG_TEMPLATE_PATH is None:
tl_template_path = osp.abspath(osp.join(project_root, "src"))
else:
tl_template_path = TILELANG_TEMPLATE_PATH
cuda_home = CUDA_HOME if CUDA_HOME else "/usr/local/cuda"
options = [f"-I{tl_template_path}", f"-I{cutlass_path}", f"-I{cuda_home}/include"]
if self.compile_flags:
options += [
item for flag in self.compile_flags for item in flag.split()
if item not in options
]
cubin_bytes = compile_cuda(
self.lib_code, target_format="cubin", options=options, verbose=verbose)
with open(libpath, "wb") as f:
f.write(cubin_bytes)
src.write(self.lib_code)
src.flush()
self.srcpath = src.name
self.libpath = libpath
pypath = src.name.replace(".cu", ".py")
with open(pypath, "w") as f:
f.write(self.host_func)
else:
raise ValueError(f"Unsupported target: {target}")
def __del__(self):
if self.culib:
result = cuda.cuLibraryUnload(self.culib)[0]
if result != cuda.CUresult.CUDA_SUCCESS:
logger.warning(f"Failed to unload library: {self.libpath}")
self.culib = None
@@ -5,7 +5,10 @@ This module provides runtime compilation support using NVIDIA's NVRTC API.
import logging
__all__ = [
'NVRTCKernelAdapter', 'TLNVRTCSourceWrapper', 'NVRTCLibraryGenerator', 'is_nvrtc_available',
'check_nvrtc_available'
]
logger = logging.getLogger(__name__)
@@ -37,7 +40,9 @@ def check_nvrtc_available():
# Conditionally import the adapter
if is_nvrtc_available:
from .adapter import NVRTCKernelAdapter
from .wrapper import TLNVRTCSourceWrapper
from .libgen import NVRTCLibraryGenerator
else:
# Provide a dummy class that raises error on instantiation
class NVRTCKernelAdapter:
@@ -45,3 +50,19 @@ else:
def __init__(self, *args, **kwargs):
raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)
@classmethod
def from_database(cls, *args, **kwargs):
raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)
class TLNVRTCSourceWrapper:
"""Dummy TLNVRTCSourceWrapper that raises ImportError on instantiation."""
def __init__(self, *args, **kwargs):
raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)
class NVRTCLibraryGenerator:
"""Dummy NVRTCLibraryGenerator that raises ImportError on instantiation."""
def __init__(self, *args, **kwargs):
raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)
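Note: with the conditional exports above, callers can probe availability before opting into this backend. A minimal sketch, assuming a fallback to the default backend (the name "cython" used here is an assumption, and the helper itself is hypothetical, not part of this commit):

from tilelang.jit.adapter.nvrtc import is_nvrtc_available

def pick_execution_backend() -> str:
    # Hypothetical helper: fall back when cuda-python/NVRTC is not installed.
    return "nvrtc" if is_nvrtc_available else "cython"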
@@ -9,12 +9,13 @@ from tvm.target import Target
from tilelang import tvm as tvm
from tilelang.engine.param import KernelParam
from tilelang.jit.adapter.wrapper import TLPyWrapper
from tilelang.jit.adapter.libgen import PyLibraryGenerator
from tilelang.utils.language import retrieve_func_from_module
from tilelang.utils.target import determine_target
from tilelang.jit.adapter.base import BaseKernelAdapter
from tilelang.jit.adapter.nvrtc import is_nvrtc_available, check_nvrtc_available
from .libgen import NVRTCLibraryGenerator
logger = logging.getLogger(__name__)
# Import cuda bindings if available
@@ -75,7 +76,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
self.wrapper.assign_device_module(device_mod)
self.host_func, self.function_names = self.wrapper.wrap(kernel_global_source)
self.lib_generator = NVRTCLibraryGenerator(self.target, self.verbose)
self.lib_generator.update_lib_code(self.kernel_global_source)
self.lib_generator.update_host_func(self.host_func)
self.lib_generator.assign_compile_flags(compile_flags)
@@ -130,7 +131,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
adapter.target = Target.canon_target(determine_target(target))
adapter.verbose = verbose
adapter.lib_generator = NVRTCLibraryGenerator(adapter.target, adapter.verbose)
adapter.lib_generator.assign_compile_flags(compile_flags)
adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
adapter.pymodule = adapter.lib_generator.pymodule
...
"""NVRTC Library Generator for TileLang.
Compiles CUDA kernels at runtime using NVRTC and manages resulting binaries.
Why NVRTC instead of nvcc:
- No offline compilation step, enables true JIT workflows
- Works without CUDA toolkit installed (only requires driver)
- Allows kernel specialization based on runtime parameters
Key responsibilities:
- Compile CUDA source to cubin using NVRTC API
- Generate accompanying Python launcher code
- Load compiled cubin and extract kernel handles
- Manage library lifecycle (load/unload)
"""
from __future__ import annotations
import importlib
import logging
import os.path as osp
import platform
import tempfile
from types import ModuleType
from tvm.target import Target
from tilelang import tvm as tvm
from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.jit.adapter.utils import is_cuda_target
from tilelang.jit.adapter.nvrtc import is_nvrtc_available, NVRTC_UNAVAILABLE_MESSAGE
logger = logging.getLogger(__name__)
if is_nvrtc_available:
import cuda.bindings.driver as cuda
from tilelang.contrib.nvrtc import compile_cuda
else:
raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)
class NVRTCLibraryGenerator(LibraryGenerator):
"""Runtime compiler and loader for NVRTC-compiled CUDA kernels.
Lifecycle:
1. compile_lib(): CUDA source → cubin + Python launcher
2. load_lib(): cubin → loaded library + kernel handles
3. pymodule.call(): Execute kernels via Python launcher
4. __del__: Cleanup (unload library)
Why three files (cu, cubin, py):
- .cu: Source for debugging, kept in temp directory
- .cubin: Compiled binary, loaded by CUDA driver
- .py: Launch code, imported as Python module
Attributes:
host_func: Generated Python launch code (from wrapper)
culib: CUDA library handle (CUlibrary)
pymodule: Imported Python module containing call() function
"""
host_func: str | None = None
culib: cuda.CUlibrary | None = None
pymodule: ModuleType | None = None
pypath: str | None = None
def __init__(self, target: Target, verbose: bool = False):
"""Initialize NVRTC library generator.
Args:
target: Compilation target (must be CUDA)
verbose: Enable verbose compilation output
"""
super().__init__(target, verbose)
@staticmethod
def import_from_file(module_name, file_path):
"""Dynamically import Python module from file path.
Standard importlib pattern for loading modules outside sys.path.
Used to import generated .py launcher code from temp directory.
Args:
module_name: Name to assign to imported module
file_path: Absolute path to .py file
Returns:
Imported module object
"""
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec is None or spec.loader is None:
raise ImportError(f"Failed to import module from file: {file_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def update_host_func(self, host_func: str):
"""Store generated Python launch code for later file write.
Called by adapter after wrapper generates the launch code.
This is the bridge between code generation and file output.
Args:
host_func: Python source code containing call() function
"""
self.host_func = host_func
def load_lib(self, lib_path: str | None = None):
"""Load compiled cubin and Python launcher into memory.
Why two loads:
1. Import Python module for launch logic
2. Load cubin via CUDA Driver API for kernel handles
Context synchronization: CUDA context must be current before loading.
If not, use torch.cuda.synchronize() to establish context.
Args:
lib_path: Path to .cubin file (optional, uses self.libpath if None)
Side effects:
- Sets self.pymodule to imported Python module
- Sets self.culib to CUDA library handle
"""
if lib_path is None:
lib_path = self.libpath
else:
self.libpath = lib_path
self.pypath = lib_path.replace(".cubin", ".py")
self.pymodule = self.import_from_file("kernel", self.pypath)
# Ensure the context is valid
ctx = cuda.cuCtxGetCurrent()[1]
if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS:
import torch
torch.cuda.synchronize()
result, self.culib = cuda.cuLibraryLoadFromFile(
bytes(lib_path, "utf-8"), [], [], 0, [], [], 0)
if result != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to load library: {lib_path}, error: {result}")
def compile_lib(self, timeout: float | None = None):
"""Compile CUDA source to cubin using NVRTC and write output files.
Output artifacts (all in temp directory):
- .cu: Source code (for debugging)
- .cubin: Compiled binary (for execution)
- .py: Python launcher (for calling kernels)
Include paths setup:
- TileLang templates: kernel primitives and utilities
- CUTLASS: optimized GEMM/tensor ops
- CUDA headers: driver/runtime APIs
Why architecture detection:
ARM64 servers (SBSA) have different header paths than x86_64.
Args:
timeout: Compilation timeout in seconds (currently unsupported by NVRTC compiler)
Side effects:
- Writes .cu, .cubin, .py files to temp directory
- Sets self.srcpath, self.libpath, self.pypath
"""
target = self.target
verbose = self.verbose
if is_cuda_target(target):
from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH)
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
libpath = src.name.replace(".cu", ".cubin")
project_root = osp.join(osp.dirname(__file__), "..", "..")
if CUTLASS_INCLUDE_DIR is None:
cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include"))
else:
cutlass_path = CUTLASS_INCLUDE_DIR
if TILELANG_TEMPLATE_PATH is None:
tl_template_path = osp.abspath(osp.join(project_root, "src"))
else:
tl_template_path = TILELANG_TEMPLATE_PATH
cuda_home = CUDA_HOME if CUDA_HOME else "/usr/local/cuda"
__CUDACC_VER_MAJOR__ = cuda.CUDA_VERSION // 1000
# Determine target architecture
machine = platform.machine()
target_arch = "sbsa-linux" if machine in ("aarch64", "arm64") else "x86_64-linux"
options = [
f"-I{tl_template_path}",
f"-I{cutlass_path}",
f"-I{cuda_home}/include",
f"-I{cuda_home}/targets/{target_arch}/include",
f"-I{cuda_home}/targets/{target_arch}/include/cccl",
f"-D__CUDACC_VER_MAJOR__={__CUDACC_VER_MAJOR__}",
]
if self.compile_flags:
options += [
item for flag in self.compile_flags for item in flag.split()
if item not in options
]
cubin_bytes = compile_cuda(
self.lib_code, target_format="cubin", options=options, verbose=verbose)
with open(libpath, "wb") as f:
f.write(cubin_bytes)
src.write(self.lib_code)
src.flush()
self.srcpath = src.name
self.libpath = libpath
self.pypath = src.name.replace(".cu", ".py")
if self.host_func is None:
raise RuntimeError(
"Host function is not set, please call update_host_func() first.")
with open(self.pypath, "w") as f:
f.write(self.host_func)
else:
raise ValueError(f"Unsupported target: {target}")
def __del__(self):
"""Cleanup: unload CUDA library when object is destroyed.
Critical for resource management - CUDA libraries consume GPU memory.
Failure to unload is logged but not raised (destructor can't fail).
Why explicit unload:
Python GC doesn't know about GPU resources, must release manually.
"""
if self.culib:
result = cuda.cuLibraryUnload(self.culib)[0]
if result != cuda.CUresult.CUDA_SUCCESS:
logger.warning(f"Failed to unload library: {self.libpath}")
self.culib = None
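Taken together, the generator follows the lifecycle described in its docstring above. A rough usage sketch (editorial illustration only; cuda_source and python_launcher stand for the outputs of codegen and TLNVRTCSourceWrapper, and in practice the adapter drives these calls and extracts the kernels dict from the loaded library):

gen = NVRTCLibraryGenerator(target, verbose=False)
gen.update_lib_code(cuda_source)       # CUDA C++ produced by codegen
gen.update_host_func(python_launcher)  # Python launch code from the wrapper
gen.compile_lib()                      # writes .cu, .cubin, .py to a temp directory
gen.load_lib()                         # imports the .py module and loads the cubin
# kernels: dict[str, CUkernel] handles obtained from gen.culib by the adapter
gen.pymodule.call(kernels, a, b, c, stream=0)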
"""NVRTC Source Wrapper for TileLang.
Generates Python runtime code for launching CUDA kernels compiled via NVRTC.
Why this exists:
- NVRTC compiles kernels at runtime, needs Python launch code (not C++)
- TMA descriptors must be initialized once per unique buffer, not per kernel
- L2 cache policies require explicit CUDA Driver API setup/teardown
Key design:
- Two-pass generation: collect all descriptors first, then generate launches
- Dict-based deduplication ensures TMA descriptors created only once
- Generates pure Python using cuda.bindings.driver for zero C++ dependency
"""
from __future__ import annotations
from typing import Any, ClassVar
from tvm import IRModule
from tvm.target import Target
from tvm.tir.stmt_functor import post_order_visit
from tilelang import tvm as tvm
from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper
from tilelang.jit.adapter.utils import (match_declare_kernel, pythonic_expr,
parse_function_call_args, parse_tma_descriptor_args)
PREDEF_HOST_FUNC_PY = """
from cuda.bindings.driver import (
CUtensorMapDataType,
CUtensorMapInterleave,
CUtensorMapSwizzle,
CUtensorMapL2promotion,
CUtensorMapFloatOOBfill,
cuTensorMapEncodeTiled,
cuTensorMapEncodeIm2col,
CUresult,
cuKernelSetAttribute,
CUfunction_attribute,
CUdevice,
CUlaunchConfig,
cuLaunchKernelEx,
cuuint64_t,
cuuint32_t,
CUkernel,
)
import ctypes
_function_names = {}
def call({}):
{}
"""
TMA_DESC_INIT_FUNC_PY = """
{0}_type = CUtensorMapDataType({1})
{0}_tensorRank = {2}
{0}_globalAddress = {3}.data_ptr()
{0}_globalDim = [{4}]
{0}_globalStride = [{5}][1:]
{0}_boxDim = [{6}]
{0}_elementStrides = [{7}]
{0}_interleave = CUtensorMapInterleave({8})
{0}_swizzle = CUtensorMapSwizzle({9})
{0}_l2Promotion = CUtensorMapL2promotion({10})
{0}_oobFill = CUtensorMapFloatOOBfill({11})
res, {0} = cuTensorMapEncodeTiled(
{0}_type,
{0}_tensorRank,
{0}_globalAddress,
{0}_globalDim,
{0}_globalStride,
{0}_boxDim,
{0}_elementStrides,
{0}_interleave,
{0}_swizzle,
{0}_l2Promotion,
{0}_oobFill,
)
if res != CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to initialize the TMA descriptor {0}: {{res}}")
"""
TMA_IM2COL_DESC_INIT_FUNC_PY = """
{0}_type = CUtensorMapDataType({1})
{0}_tensorRank = {2}
{0}_globalAddress = {3}.data_ptr()
{0}_globalDim = [{4}]
{0}_globalStride = [{5}][1:]
{0}_elementStrides = [{6}]
{0}_lowerCorner = [{7}]
{0}_upperCorner = [{8}]
{0}_channelsPerPixel = {9}
{0}_pixelsPerColumn = {10}
{0}_interleave = CUtensorMapInterleave({11})
{0}_swizzle = CUtensorMapSwizzle({12})
{0}_l2Promotion = CUtensorMapL2promotion({13})
{0}_oobFill = CUtensorMapFloatOOBfill({14})
res, {0} = cuTensorMapEncodeIm2col(
{0}_type,
{0}_tensorRank,
{0}_globalAddress,
{0}_globalDim,
{0}_globalStride,
{0}_lowerCorner,
{0}_upperCorner,
{0}_channelsPerPixel,
{0}_pixelsPerColumn,
{0}_elementStrides,
{0}_interleave,
{0}_swizzle,
{0}_l2Promotion,
{0}_oobFill,
)
if res != CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to initialize the TMA descriptor {0}: {{res}}")
"""
L2_PERSISTENT_MAP_CREATE_HANDLE_PY = """
from cuda.bindings.driver import (
CUstreamAttrValue,
CUstreamAttrID,
CUlimit,
CUaccessProperty,
cuCtxGetLimit,
cuCtxSetLimit,
cuStreamSetAttribute,
cuCtxResetPersistingL2Cache,
)
stream_attribute = CUstreamAttrValue()
res, init_persisting_l2_cache_size = cuCtxGetLimit(CUlimit.CU_LIMIT_PERSISTING_L2_CACHE_SIZE)
if res != CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to get L2 cache size limit: {{res}}")
"""
L2_PERSISTENT_MAP_INIT_FUNC_PY = """
stream_attribute.accessPolicyWindow.hitRatio = {1}
stream_attribute.accessPolicyWindow.hitProp = CUaccessProperty.CU_ACCESS_PROPERTY_PERSISTING
stream_attribute.accessPolicyWindow.missProp = CUaccessProperty.CU_ACCESS_PROPERTY_STREAMING
res = cuCtxSetLimit(CUlimit.CU_LIMIT_PERSISTING_L2_CACHE_SIZE, {2})[0]
if res != CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to set L2 cache size limit: {{res}}")
stream_attribute.accessPolicyWindow.base_ptr = {0}.data_ptr()
stream_attribute.accessPolicyWindow.num_bytes = {2}
res = cuStreamSetAttribute(stream, CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW, stream_attribute)[0]
if res != CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to set stream L2 access policy: {{res}}")
"""
L2_PERSISTENT_MAP_RESET_HANDLE_PY = """
stream_attribute.accessPolicyWindow.num_bytes = 0
res = cuStreamSetAttribute(stream, CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW, stream_attribute)[0]
if res != CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to reset stream L2 access policy: {{res}}")
res = cuCtxResetPersistingL2Cache()[0]
if res != CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to reset L2 cache: {{res}}")
res = cuCtxSetLimit(CUlimit.CU_LIMIT_PERSISTING_L2_CACHE_SIZE, init_persisting_l2_cache_size)[0]
if res != CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to restore L2 cache size limit: {{res}}")
"""
KERNEL_LAUNCH_FUNC_PY = """
res = cuKernelSetAttribute(
CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
{7},
kernels["{0}"],
CUdevice({10})
)[0]
if res != CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to set max dynamic shared memory size to {7} for kernel {0}: {{res}}")
config = CUlaunchConfig()
config.gridDimX = {1}
config.gridDimY = {2}
config.gridDimZ = {3}
config.blockDimX = {4}
config.blockDimY = {5}
config.blockDimZ = {6}
config.sharedMemBytes = {7}
config.hStream = stream
arg_values = {8}
arg_types = {9}
res = cuLaunchKernelEx(config, kernels["{0}"], (arg_values, arg_types), 0)[0]
if res != CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to launch kernel {0}: {{res}}")
"""
class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
"""NVRTC backend wrapper: generates Python kernel launch code.
Core responsibility: transform TVM IRModule into executable Python function
that initializes resources (TMA descriptors, L2 cache) and launches kernels
via CUDA Driver API.
Data flow:
IRModule → collect kernel metadata → deduplicate resources →
generate Python code → executable function
Why Python generation instead of C++:
NVRTC workflow requires runtime compilation, Python is the natural host.
Using cuda.bindings.driver eliminates C++ wrapper complexity.
"""
_TYPE_MAP: ClassVar[dict[str, str]] = {
"float32": "ctypes.c_float",
"float16": "ctypes.c_uint16",
"bfloat16": "ctypes.c_uint16",
"float8_e4m3": "ctypes.c_uint8",
"float8_e4m3fn": "ctypes.c_uint8",
"float8_e5m2": "ctypes.c_uint8",
"float64": "ctypes.c_double",
"int64": "ctypes.c_int64",
"int32": "ctypes.c_int32",
"uint32": "ctypes.c_uint32",
"bool": "ctypes.c_bool",
"int8": "ctypes.c_int8",
"uint8": "ctypes.c_uint8",
"int16": "ctypes.c_int16",
"uint16": "ctypes.c_uint16",
"uchar": "ctypes.c_uint8",
}
_generated_host_func: str | None = None
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: IRModule | None = None,
host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None):
"""Initialize NVRTC wrapper with compiled IR modules.
Args:
scheduled_ir_module: TVM IR after scheduling passes
source: Generated CUDA C++ source code
target: Compilation target (should be NVRTC-compatible)
device_mod: Device-side IR module (kernel functions)
host_mod: Host-side IR module (launch logic)
pass_configs: Optional compiler pass configurations
"""
super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)
@property
def host_func(self):
"""Override parent's host_func to return generated Python code."""
if self._generated_host_func is not None:
return self._generated_host_func
return super().host_func
@host_func.setter
def host_func(self, value):
"""Allow setting generated host function code."""
self._generated_host_func = value
def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str:
"""Convert TVM expression to Python string, ignoring casts.
Casts are noise in generated Python code - Python is dynamically typed.
"""
return pythonic_expr(expr, self._TYPE_MAP, ignore_cast=True)
def create_dispatch_func(self, code, function_informations):
"""Generate Python dispatch function that launches multiple CUDA kernels.
Why two-pass design:
Pass 1: Collect TMA descriptors from all kernels into shared dicts
Pass 2: Generate code - descriptors first (deduplicated), then launches
Single-pass would create duplicate descriptors for each kernel.
Dict naturally deduplicates by descriptor name.
Args:
code: CUDA C++ source containing kernel declarations
function_informations: Dict mapping kernel names to metadata
(grid/block dims, params, shared memory size)
Returns:
Python source code defining a call() function that:
1. Initializes L2 cache policies (if needed)
2. Creates TMA descriptors once per unique buffer
3. Launches each kernel with cuLaunchKernelEx
4. Resets L2 cache policies (if needed)
"""
# Extract the set of dynamic symbolic names used in the primary function
dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
function_args = [{"name": "kernels", "type": "dict[str, CUkernel]"}]
# Collect function arguments based on primary function's parameters and buffer mappings
for param in self.prim_func.params:
if param in self.prim_func.buffer_map:
buffer = self.prim_func.buffer_map[param]
function_args.append({
"name": buffer.data.name,
"type": "ctypes.c_void_p",
})
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else:
raise ValueError(
f"Parameter {param} is not in the buffer map of the primary function.")
# Add dynamic symbols as integer arguments
for dyn_sym in dynamic_symbolic_set:
if dyn_sym not in [arg["name"] for arg in function_args]:
function_args.append({"name": dyn_sym, "type": "ctypes.c_int"})
function_args.append(self.get_stream_type())
# Format the function arguments for declaration
def_args = ", ".join([f"{arg['name']}" for arg in function_args])
# Check if any function needs L2 Persistent Map
has_l2_persistent_map = False
for function_name, _ in function_informations.items():
if function_name in self.l2_persistent_map:
has_l2_persistent_map = True
break
desc_name_map: dict[str, str] = {}
desc_name_var_map: dict[str, tvm.tir.Var] = {}
device_index = 0
kernel_launch_code = """"""
if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE_PY
# First pass: collect all TMA descriptors from all kernels to avoid duplication
kernel_info_list = []
for function_name, function_info in function_informations.items():
block_info = function_info["block_info"]
grid_info = function_info["grid_info"]
dynamic_smem_buf = function_info["dynamic_smem_buf"]
function_params = function_info["function_params"]
# Find the location of the global kernel function in the code
index = match_declare_kernel(code, function_name + "(")
# Analyze the function declaration to prepare for argument extraction
declaration = code[index:].split(";")[0]
# Identify the start of the function body to insert arguments
index = code.index("{", index)
# Transform function for NVRTC: returns (arg_value, arg_type) tuples
def transform_nvrtc_arg(name: str, arg_type: str):
if arg_type == "ctypes.c_void_p":
return (f"{name}.data_ptr()", arg_type)
return (name, arg_type)
call_args = parse_function_call_args(declaration, function_args, function_params,
desc_name_map, desc_name_var_map,
transform_nvrtc_arg)
for arg_name, arg_type in call_args:
if arg_type == "ctypes.c_void_p":
device_index = f"{arg_name.replace('.data_ptr()', '')}.device.index"
break
# Store kernel info for second pass
kernel_info_list.append({
'function_name': function_name,
'block_info': block_info,
'grid_info': grid_info,
'dynamic_smem_buf': dynamic_smem_buf,
'call_args': call_args,
'device_index': device_index,
})
# Generate TMA descriptor initialization code once for all kernels
kernel_launch_code += self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map)
# Second pass: generate kernel launch code for each kernel
for kernel_info in kernel_info_list:
function_name = kernel_info['function_name']
block_info = kernel_info['block_info']
grid_info = kernel_info['grid_info']
dynamic_smem_buf = kernel_info['dynamic_smem_buf']
call_args = kernel_info['call_args']
device_index = kernel_info['device_index']
arg_names = ", ".join([arg[0] for arg in call_args])
arg_types = ", ".join([arg[1] for arg in call_args])
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
# Generate L2 persistent map initialization for this function
init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
kernel_launch_code += init_l2_persistent_map
# Generate kernel launch code
kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format(function_name,
self._pythonic_expr(grid_info[0]),
self._pythonic_expr(grid_info[1]),
self._pythonic_expr(grid_info[2]),
self._pythonic_expr(block_info[0]),
self._pythonic_expr(block_info[1]),
self._pythonic_expr(block_info[2]),
smem_str, arg_names, arg_types,
device_index)
# Reset L2 persistent map after all kernel execution
if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE_PY
# Wrap the kernel dispatch logic in an external C function
host_func = PREDEF_HOST_FUNC_PY.format(
repr(list(function_informations.keys())), def_args, kernel_launch_code)
return host_func
def generate_l2_persistent_map(self, function_name: str) -> str:
"""Generate Python code to configure L2 cache persistence for a kernel.
L2 persistence pins frequently-accessed data in L2 cache to reduce
memory bandwidth. Requires explicit setup via CUDA stream attributes.
Args:
function_name: Kernel name to check for L2 persistence config
Returns:
Python code that sets stream access policy window, or empty
string if no L2 persistence configured for this kernel.
"""
if function_name not in self.l2_persistent_map:
return ""
init_l2_persistent_map = ""
for buffer_name, (hit_ratio,
size_in_bytes) in self.l2_persistent_map[function_name].items():
# Get persisting_l2_cache_max_size
from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size
persisting_l2_cache_max_size = get_persisting_l2_cache_max_size()
try:
num_bytes = min(size_in_bytes, persisting_l2_cache_max_size)
except TypeError:
# as size_in_bytes may be a symbolic expression
num_bytes = persisting_l2_cache_max_size
init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC_PY.format(
buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes))
return init_l2_persistent_map
def generate_tma_descriptor_args(self, desc_name_map: dict[str, str],
desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
"""Generate Python code to initialize TMA descriptors.
TMA (Tensor Memory Accelerator) descriptors are opaque CUDA objects
that describe memory layout for async copies. Must be created on host
before kernel launch.
Args:
desc_name_map: Maps descriptor variable names to buffer names
desc_name_var_map: Maps descriptor names to TVM variables
Returns:
Python code that calls cuTensorMapEncodeTiled/Im2col for each
unique descriptor. Empty string if no TMA descriptors needed.
"""
tma_descriptor_init = ""
if self.tma_descriptor_args is None:
return tma_descriptor_init
# Parse TMA descriptor arguments using the common utility
parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map,
desc_name_var_map, self._pythonic_expr)
# Generate Python code from parsed parameters
for params in parsed_params:
if not params.is_img2col:
tma_descriptor_init += TMA_DESC_INIT_FUNC_PY.format(
params.handle_name, params.dtype, params.tensor_rank, params.global_address,
", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)),
", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)),
", ".join(map(lambda x: f"cuuint32_t({x})", params.box_dim)),
", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)),
params.interleave, params.swizzle, params.l2_promotion, params.oob_fill)
else:
tma_descriptor_init += TMA_IM2COL_DESC_INIT_FUNC_PY.format(
params.handle_name, params.dtype, params.tensor_rank, params.global_address,
", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)),
", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)),
", ".join(map(lambda x: f"cuuint32_t({x})",
params.element_strides)), ", ".join(params.lower_corner),
", ".join(params.upper_corner), params.smem_box_channel, params.smem_box_pixel,
params.interleave, params.swizzle, params.l2_promotion, params.oob_fill)
return tma_descriptor_init
def update_lib_code(self, code: str):
"""Update library code and generate host dispatch function.
Entry point for code generation. Walks the host IR to extract kernel
call sites, matches them with device kernels, then generates Python
dispatch code via create_dispatch_func().
Args:
code: CUDA C++ source code containing compiled kernels
Returns:
The same code string (stored in self.lib_code). Side effect:
sets self.host_func to generated Python dispatcher.
"""
# Update the library code with the given code string
self.lib_code = code
# Organize function information for code generation
function_informations = {}
for function_name in self.function_names:
# Do not update function with dispatch host function
if (function_name not in self.block_info) or (function_name not in self.grid_info):
continue
assert function_name in self.device_mod, f"Function {function_name} not found in device module"
device_func = self.device_mod[function_name]
kernel_params_cnt = len(device_func.params)
function_params: list[str] | None = None
def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
nonlocal function_params
if isinstance(node, tvm.tir.Call):
if not (hasattr(node, "op") and
node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
return
args = node.args
if not args or args[0] != fn:
return
if len(args) < 1 + param_cnt:
raise AssertionError(
"tvm_call_packed should have at least 1 argument and match device function parameters"
)
function_params = args[1:1 + param_cnt]
post_order_visit(self.host_func.body, visitor)
assert function_params is not None, "function_params should not be None"
function_informations[function_name] = {
"function_name": function_name,
"block_info": self.block_info[function_name],
"grid_info": self.grid_info[function_name],
"dynamic_smem_buf": self.dynamic_smem_buf[function_name],
"function_params": function_params,
}
# Create the host function wrapper for the CUDA kernel
self.host_func = self.create_dispatch_func(code, function_informations)
return self.lib_code
def get_stream_type(self) -> dict[str, str]:
"""Return stream parameter spec for Python signature.
NVRTC backend uses raw int for stream handle (not cudaStream_t pointer).
Default to 0 (NULL stream) for convenience.
"""
return {"name": "stream=0", "type": "int"}
from __future__ import annotations
import re
from typing import Literal, Callable, Any
from tilelang import tvm as tvm
from tvm import IRModule, tir
from tvm.target import Target
@@ -107,13 +107,16 @@ def get_annotated_mod(
return dispatch[model_type](mod)
def pythonic_expr(expr: tvm.tir.PrimExpr,
dtype_map: dict[str, str] | None = None,
ignore_cast: bool = False) -> str:
"""
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
Args:
expr: The TVM PrimExpr to convert.
dtype_map: A dictionary mapping data types to their string representations.
ignore_cast: Whether to ignore the cast operator and return the string representation of the value without the cast.
Returns:
A string representation of the expression.
"""
@@ -158,10 +161,11 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = Non
elif isinstance(node, tvm.tir.Cast):
# C-style cast has high precedence
value_str, _ = node_to_result_map[node.value]
if ignore_cast:
s = value_str
else:
type_str = node.dtype if dtype_map is None else dtype_map[node.dtype]
s = f"({type_str}){value_str}"
p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE)
elif isinstance(
node,
@@ -216,3 +220,238 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = Non
tvm.tir.stmt_functor.post_order_visit(expr, _visitor)
return next(iter(node_to_result_map[expr]), "")
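The new ignore_cast flag simply drops cast wrappers from the rendered expression, which keeps the generated Python launchers free of C-style casts. A small illustration of the intended behavior (the printed strings are approximate):

import tvm
from tilelang.jit.adapter.utils import pythonic_expr

n = tvm.tir.Var("n", "int32")
expr = tvm.tir.Cast("int64", n) + 1
print(pythonic_expr(expr))                    # e.g. "(int64)n + 1"
print(pythonic_expr(expr, ignore_cast=True))  # e.g. "n + 1"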
def maybe_desc_name(name: str,
matches: list[str],
i: int,
desc_name_map: dict[str, str] | None = None) -> bool:
"""
Check if a parameter name corresponds to a TMA descriptor.
Args:
name: The parameter name to check.
matches: List of all matched parameter names.
i: Index of the current match.
desc_name_map: Optional mapping to store descriptor name relationships.
Returns:
True if the parameter is a TMA descriptor.
"""
match = matches[i]
if not (match == name + "_desc" or match.startswith(name + "_desc_")):
return False
desc_decls = []
if desc_name_map is not None:
desc_name_map[match] = name
if i > 0:
desc_decls.append(matches[i - 1])
if i < len(matches) - 1:
desc_decls.append(matches[i + 1])
return any([decl == "CUtensorMap" for decl in desc_decls])
def parse_function_call_args(
declaration: str,
function_args: list[dict[str, str]],
function_params: list[Any],
desc_name_map: dict[str, str] | None = None,
desc_name_var_map: dict[str, tvm.tir.Var] | None = None,
transform_arg: Callable[[str, str], Any] | None = None,
) -> list[Any]:
"""
Parse function call arguments from a kernel declaration.
Args:
declaration: The kernel function declaration string.
function_args: List of function argument specifications.
function_params: List of function parameters from TVM IR.
desc_name_map: Optional mapping for descriptor names.
desc_name_var_map: Optional mapping from descriptor names to TVM variables.
transform_arg: Optional function to transform each argument (name, type) -> result.
Returns:
List of parsed call arguments.
"""
pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
matches = re.findall(pattern, declaration)
call_args = []
for i, match in enumerate(matches):
for arg in function_args:
if arg["name"] == match:
if transform_arg is not None:
call_args.append(transform_arg(match, arg["type"]))
else:
call_args.append(match)
elif maybe_desc_name(arg["name"], matches, i, desc_name_map):
if transform_arg is not None:
call_args.append(transform_arg(match, "None"))
else:
call_args.append(match)
if desc_name_var_map is not None and function_params is not None:
assert len(call_args) <= len(function_params), \
f"Too many arguments: {len(call_args)} > {len(function_params)}"
desc_name_var_map[match] = function_params[len(call_args) - 1]
return call_args
class TMADescriptorParams:
"""Parsed TMA descriptor parameters."""
def __init__(self,
handle_name: str,
dtype: str,
tensor_rank: int,
global_address: Any,
is_img2col: bool = False):
self.handle_name = handle_name
self.dtype = dtype
self.tensor_rank = tensor_rank
self.global_address = global_address
self.is_img2col = is_img2col
# Common fields
self.global_dim: list[str] = []
self.global_stride: list[str] = []
self.element_strides: list[str] = []
self.interleave: str = ""
self.swizzle: str = ""
self.l2_promotion: str = ""
self.oob_fill: str = ""
# Tiled-specific fields
self.box_dim: list[str] = []
# Im2col-specific fields
self.lower_corner: list[str] = []
self.upper_corner: list[str] = []
self.smem_box_channel: str = ""
self.smem_box_pixel: str = ""
def parse_tma_descriptor_args(
tma_descriptor_args: dict[tvm.tir.Var, list[Any]],
desc_name_map: dict[str, str],
desc_name_var_map: dict[str, tvm.tir.Var],
pythonic_expr_func: Callable[[Any], str],
) -> list[TMADescriptorParams]:
"""
Parse TMA descriptor arguments into structured parameters.
Args:
tma_descriptor_args: Dictionary mapping TMA descriptor variables to their arguments.
desc_name_map: Mapping from descriptor handles to parameter names.
desc_name_var_map: Mapping from descriptor handles to TVM variables.
pythonic_expr_func: Function to convert TVM expressions to strings.
Returns:
List of parsed TMA descriptor parameters.
"""
if not tma_descriptor_args:
return []
results = []
for handle_name, _ in desc_name_map.items():
assert handle_name in desc_name_var_map, \
f"Handle name {handle_name} not found in desc_name_var_map"
desc_var = desc_name_var_map[handle_name]
assert desc_var in tma_descriptor_args, \
f"TMA descriptor {desc_var} not found in {tma_descriptor_args}"
args = tma_descriptor_args[desc_var]
# args[0] is the tensormap-create intrinsic name; the second element is unused here,
# so it is skipped in the unpacking below (mirroring the CUDA wrapper).
if len(args) < 5:
raise ValueError(
f"TMA descriptor args too short: {len(args)} elements, expected at least 5")
tma_create_str, _, dtype, tensor_rank, global_address, *remaining_args = args
is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
# Convert basic fields
dtype = pythonic_expr_func(dtype)
tensor_rank = int(pythonic_expr_func(tensor_rank))
# Validate tensor_rank
if not isinstance(tensor_rank, int) or tensor_rank <= 0:
raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer")
params = TMADescriptorParams(handle_name, dtype, tensor_rank, global_address, is_img2col)
if not is_img2col:
# Tiled mode
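# Expected remaining_args layout (tiled):
#   [global_dim x rank, global_stride x rank, box_dim x rank, element_strides x rank,
#    interleave, swizzle, l2_promotion, oob_fill]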
expected_args_len = 4 * tensor_rank + 4
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
# Extract dimensions and strides
params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]]
params.global_stride = [
pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank]
]
params.box_dim = [
pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank]
]
params.element_strides = [
pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank]
]
# Extract remaining parameters
try:
interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
params.interleave = pythonic_expr_func(interleave)
params.swizzle = pythonic_expr_func(swizzle)
params.l2_promotion = pythonic_expr_func(l2_promotion)
params.oob_fill = pythonic_expr_func(oob_fill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e
else:
# Im2col mode
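# Expected remaining_args layout (im2col):
#   [global_dim x rank, global_stride x rank, element_strides x rank,
#    lower_corner x (rank - 2), upper_corner x (rank - 2),
#    smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill]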
expected_args_len = 5 * tensor_rank + 2
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
# Extract dimensions and strides
params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]]
params.global_stride = [
pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank]
]
params.element_strides = [
pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank]
]
params.lower_corner = [
pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
]
params.upper_corner = [
pythonic_expr_func(i)
for i in remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
]
# Extract remaining parameters
try:
smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = \
remaining_args[5 * tensor_rank - 4:5 * tensor_rank + 2]
params.smem_box_pixel = pythonic_expr_func(smem_box_pixel)
params.smem_box_channel = pythonic_expr_func(smem_box_channel)
params.interleave = pythonic_expr_func(interleave)
params.swizzle = pythonic_expr_func(swizzle)
params.l2_promotion = pythonic_expr_func(l2_promotion)
params.oob_fill = pythonic_expr_func(oob_fill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 6 TMA parameters "
"(smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)"
) from e
results.append(params)
return results
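# Sketch of the intended call flow in the source wrappers below (the helper names are
# real, the surrounding variables are illustrative):
#   desc_name_map, desc_name_var_map = {}, {}
#   call_args = parse_function_call_args(declaration, function_args, function_params,
#                                        desc_name_map, desc_name_var_map)
#   for p in parse_tma_descriptor_args(tma_descriptor_args, desc_name_map,
#                                      desc_name_var_map, pythonic_expr_func):
#       init_code += (TMA_IM2COL_DESC_INIT_FUNC if p.is_img2col
#                     else TMA_DESC_INIT_FUNC).format(...)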
...@@ -5,7 +5,8 @@ from typing import Any ...@@ -5,7 +5,8 @@ from typing import Any
from tvm import IRModule from tvm import IRModule
from tvm.target import Target from tvm.target import Target
from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target,
is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr) is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr,
parse_function_call_args, parse_tma_descriptor_args)
import re import re
import logging import logging
import textwrap import textwrap
...@@ -49,16 +50,6 @@ extern "C" int call({}) {{ ...@@ -49,16 +50,6 @@ extern "C" int call({}) {{
}} }}
""" """
PREDEF_HOST_FUNC_PY = """
import cuda.bindings.driver
import ctypes
_function_names = {}
def call({}):
{}
"""
L2_PERSISTENT_MAP_CREATE_HANDLE = """ L2_PERSISTENT_MAP_CREATE_HANDLE = """
\tcudaStreamAttrValue stream_attribute; \tcudaStreamAttrValue stream_attribute;
\tsize_t init_persisting_l2_cache_size; \tsize_t init_persisting_l2_cache_size;
...@@ -136,65 +127,6 @@ TMA_IM2COL_DESC_INIT_FUNC = """ ...@@ -136,65 +127,6 @@ TMA_IM2COL_DESC_INIT_FUNC = """
\t}} \t}}
""" """
TMA_DESC_INIT_FUNC_PY = """
\t{0}_type = cuda.bindings.driver.CUtensorMapDataType({1})
\t{0}_tensorRank = {2}
\t{0}_globalAddress = {3}.data_ptr()
\t{0}_globalDim = [{4}]
\t{0}_globalStride = [{5}][1:]
\t{0}_boxDim = [{6}]
\t{0}_elementStrides = [{7}]
\t{0}_interleave = cuda.bindings.driver.CUtensorMapInterleave({8})
\t{0}_swizzle = cuda.bindings.driver.CUtensorMapSwizzle({9})
\t{0}_l2Promotion = cuda.bindings.driver.CUtensorMapL2promotion({10})
\t{0}_oobFill = cuda.bindings.driver.CUtensorMapFloatOOBfill({11})
\tres, {0} = cuda.bindings.driver.cuTensorMapEncodeTiled(
\t\t{0}_type,
\t\t{0}_tensorRank,
\t\t{0}_globalAddress,
\t\t{0}_globalDim,
\t\t{0}_globalStride,
\t\t{0}_boxDim,
\t\t{0}_elementStrides,
\t\t{0}_interleave,
\t\t{0}_swizzle,
\t\t{0}_l2Promotion,
\t\t{0}_oobFill,
\t)
\tif res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
\t\traise RuntimeError(f"Failed to initialize the TMA descriptor {0}: {{res}}")
"""
KERNEL_LAUNCH_FUNC_PY = """
\tres = cuda.bindings.driver.cuKernelSetAttribute(
\t\tcuda.bindings.driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
\t\t{7},
\t\tkernels["{0}"],
\t\tcuda.bindings.driver.CUdevice({10})
\t)[0]
\tif res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
\t\traise RuntimeError(f"Failed to set max dynamic shared memory size to {7} for kernel {0}: {{res}}")
\tconfig = cuda.bindings.driver.CUlaunchConfig()
\tconfig.gridDimX = {1}
\tconfig.gridDimY = {2}
\tconfig.gridDimZ = {3}
\tconfig.blockDimX = {4}
\tconfig.blockDimY = {5}
\tconfig.blockDimZ = {6}
\tconfig.sharedMemBytes = {7}
\tconfig.hStream = stream
\targ_values = {8}
\targ_types = {9}
\tres = cuda.bindings.driver.cuLaunchKernelEx(config, kernels["{0}"], (arg_values, arg_types), 0)[0]
\tif res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
\t\traise RuntimeError(f"Failed to launch kernel {0}: {{res}}")
"""
class BaseWrapper(ABC): class BaseWrapper(ABC):
...@@ -297,41 +229,6 @@ class TLCUDASourceWrapper: ...@@ -297,41 +229,6 @@ class TLCUDASourceWrapper:
# Format the function arguments for declaration # Format the function arguments for declaration
def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
def func_call_args(s,
function_args,
function_params,
desc_name_map: dict[str, str] | None = None,
desc_name_var_map: dict[str, tvm.tir.Var] | None = None):
# Extract the function call arguments matching the function definition
def maybe_desc(name: str, matches: list[str], i: int):
match = matches[i]
if not (match == name + "_desc" or match.startswith(name + "_desc_")):
return False
desc_decls = []
if desc_name_map is not None:
desc_name_map[match] = name
if i > 0:
desc_decls.append(matches[i - 1])
if i < len(matches) - 1:
desc_decls.append(matches[i + 1])
return any([decl == "CUtensorMap" for decl in desc_decls])
pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
matches = re.findall(pattern, s)
call_args = []
for i, match in enumerate(matches):
for arg in function_args:
if arg["name"] == match:
call_args.append(match)
elif maybe_desc(arg["name"], matches, i):
call_args.append(match)
assert len(call_args) <= len(
function_params
), f"Function {function_name} has {len(function_params)} parameters, but {len(call_args)} arguments"
desc_name_var_map[match] = function_params[len(call_args) - 1]
return call_args
has_l2_persistent_map = False has_l2_persistent_map = False
for function_name, _ in function_informations.items(): for function_name, _ in function_informations.items():
if function_name in self.l2_persistent_map: if function_name in self.l2_persistent_map:
...@@ -365,8 +262,8 @@ class TLCUDASourceWrapper: ...@@ -365,8 +262,8 @@ class TLCUDASourceWrapper:
kernel_launch_code += init_l2_persistent_map kernel_launch_code += init_l2_persistent_map
if self.use_cooperative_groups[function_name]: if self.use_cooperative_groups[function_name]:
args_list = func_call_args(declaration, function_args, function_params, args_list = parse_function_call_args(declaration, function_args, function_params,
desc_name_map, desc_name_var_map) desc_name_map, desc_name_var_map)
assert len(function_params) == len( assert len(function_params) == len(
args_list args_list
), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
...@@ -377,8 +274,8 @@ class TLCUDASourceWrapper: ...@@ -377,8 +274,8 @@ class TLCUDASourceWrapper:
kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format( kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format(
function_name, grid_str, block_str, function_name + "_args", smem_str) function_name, grid_str, block_str, function_name + "_args", smem_str)
else: else:
args_list = func_call_args(declaration, function_args, function_params, args_list = parse_function_call_args(declaration, function_args, function_params,
desc_name_map, desc_name_var_map) desc_name_map, desc_name_var_map)
assert len(function_params) == len( assert len(function_params) == len(
args_list args_list
), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
...@@ -420,101 +317,26 @@ class TLCUDASourceWrapper: ...@@ -420,101 +317,26 @@ class TLCUDASourceWrapper:
tma_descripter_init = "" tma_descripter_init = ""
if self.tma_descriptor_args is None: if self.tma_descriptor_args is None:
return tma_descripter_init return tma_descripter_init
for handle_name, _ in desc_name_map.items():
assert handle_name in desc_name_var_map, f"Handle name {handle_name} not found in desc_name_var_map"
desc_var = desc_name_var_map[handle_name]
assert desc_var in self.tma_descriptor_args, f"TMA descriptor {desc_var} not found in {self.tma_descriptor_args}"
args = self.tma_descriptor_args[desc_var]
# Skip __tvm_tensormap_create_tiled
if len(args) < 3:
raise ValueError(
f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
tma_create_str, _, dtype, tensor_rank, globalAddress, *remaining_args = args
is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
dtype = self._pythonic_expr(dtype)
tensor_rank = int(self._pythonic_expr(tensor_rank))
# Validate tensor_rank
if not isinstance(tensor_rank, int) or tensor_rank <= 0:
raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer")
if not is_img2col:
# Calculate required length for remaining_args
expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
box_dim = [self._pythonic_expr(i) for i in box_dim]
element_strides = [self._pythonic_expr(i) for i in element_strides]
# Extract remaining parameters
try:
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e
# Parse TMA descriptor arguments using the common utility
parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map,
desc_name_var_map, self._pythonic_expr)
# Generate C++ code from parsed parameters
for params in parsed_params:
if not params.is_img2col:
tma_descripter_init += TMA_DESC_INIT_FUNC.format( tma_descripter_init += TMA_DESC_INIT_FUNC.format(
handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim), params.handle_name, params.dtype, params.tensor_rank, params.global_address,
",".join(global_stride), ",".join(box_dim), ",".join(element_strides), ",".join(params.global_dim), ",".join(params.global_stride),
interleave, swizzle, l2Promotion, oobFill) ",".join(params.box_dim), ",".join(params.element_strides), params.interleave,
params.swizzle, params.l2_promotion, params.oob_fill)
else: else:
# Calculate required length for remaining_args
expected_args_len = 5 * tensor_rank + 2
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
element_strides = remaining_args[2 * tensor_rank:3 * tensor_rank]
lower_corner = remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
upper_corner = remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
element_strides = [self._pythonic_expr(i) for i in element_strides]
lower_corner = [self._pythonic_expr(i) for i in lower_corner]
upper_corner = [self._pythonic_expr(i) for i in upper_corner]
# Extract remaining parameters
try:
smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill = remaining_args[
5 * tensor_rank - 4:5 * tensor_rank + 2]
smem_box_pixel = self._pythonic_expr(smem_box_pixel)
smem_box_channel = self._pythonic_expr(smem_box_channel)
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 6 TMA parameters (smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)"
) from e
tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format( tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format(
handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim), params.handle_name, params.dtype, params.tensor_rank, params.global_address,
",".join(global_stride), ",".join(element_strides), ",".join(lower_corner), ",".join(params.global_dim), ",".join(params.global_stride),
",".join(upper_corner), smem_box_channel, smem_box_pixel, interleave, swizzle, ",".join(params.element_strides), ",".join(params.lower_corner),
l2Promotion, oobFill) ",".join(params.upper_corner), params.smem_box_channel, params.smem_box_pixel,
params.interleave, params.swizzle, params.l2_promotion, params.oob_fill)
return tma_descripter_init return tma_descripter_init
...@@ -713,213 +535,6 @@ class TLCUDASourceWrapper: ...@@ -713,213 +535,6 @@ class TLCUDASourceWrapper:
raise ValueError("Cannot find primary function in the module.") raise ValueError("Cannot find primary function in the module.")
class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
"""
A wrapper class for the TileLang NVRTC backend.
"""
_TYPE_MAP = {
"float32": "ctypes.c_float",
"float16": "ctypes.c_uint16",
"bfloat16": "ctypes.c_uint16",
"float8_e4m3": "ctypes.c_uint8",
"float8_e4m3fn": "ctypes.c_uint8",
"float8_e5m2": "ctypes.c_uint8",
"float64": "ctypes.c_double",
"int64": "ctypes.c_int64",
"int32": "ctypes.c_int32",
"uint32": "ctypes.c_uint32",
"bool": "ctypes.c_bool",
"int8": "ctypes.c_int8",
"uint8": "ctypes.c_uint8",
"int16": "ctypes.c_int16",
"uint16": "ctypes.c_uint16",
"uchar": "ctypes.c_uint8",
}
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: IRModule | None = None,
host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None):
super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)
def create_dispatch_func(self, code, function_informations):
# Extract the set of dynamic symbolic names used in the primary function
dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
function_args = [{"name": "kernels", "type": "Dict[str, cuda.bindings.driver.CUkernel]"}]
# Collect function arguments based on primary function's parameters and buffer mappings
for param in self.prim_func.params:
if param in self.prim_func.buffer_map:
buffer = self.prim_func.buffer_map[param]
function_args.append({
"name": buffer.data.name,
"type": "ctypes.c_void_p",
})
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else:
raise ValueError(
f"Parameter {param} is not in the buffer map of the primary function.")
# Add dynamic symbols as integer arguments
for dyn_sym in dynamic_symbolic_set:
if dyn_sym not in [arg["name"] for arg in function_args]:
function_args.append({"name": dyn_sym, "type": "ctypes.c_int"})
function_args.append(self.get_stream_type())
# Format the function arguments for declaration
def_args = ", ".join([f"{arg['name']}" for arg in function_args])
def func_call_args(s, function_args, desc_name_map: dict[str, str] | None = None):
# Extract the function call arguments matching the function definition
def maybe_desc(name: str, matches: list[str], i: int):
match = matches[i]
if not (match == name + "_desc" or match.startswith(name + "_desc_")):
return False
desc_decls = []
if desc_name_map is not None:
desc_name_map[match] = name
if i > 0:
desc_decls.append(matches[i - 1])
if i < len(matches) - 1:
desc_decls.append(matches[i + 1])
return any([decl == "CUtensorMap" for decl in desc_decls])
pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
matches = re.findall(pattern, s)
call_args = []
for i, match in enumerate(matches):
for arg in function_args:
if arg["name"] == match:
call_args.append(
(f"{match}.data_ptr()" if arg["type"] == "ctypes.c_void_p" else match,
arg["type"]))
elif maybe_desc(arg["name"], matches, i):
call_args.append((match, "None"))
return call_args
desc_name_map: dict[str, str] = {}
device_index = 0
kernel_launch_code = """"""
for function_name, function_info in function_informations.items():
block_info = function_info["block_info"]
grid_info = function_info["grid_info"]
dynamic_smem_buf = function_info["dynamic_smem_buf"]
# Find the location of the global kernel function in the code
index = match_declare_kernel(code, function_name + "(")
# Analyze the function declaration to prepare for argument extraction
declaration = code[index:].split(";")[0]
# Identify the start of the function body to insert arguments
index = code.index("{", index)
call_args = func_call_args(declaration, function_args, desc_name_map)
for arg_name, arg_type in call_args:
if arg_type == "ctypes.c_void_p":
device_index = f"{arg_name.replace('.data_ptr()', '')}.device.index"
break
arg_names = ", ".join([arg[0] for arg in call_args])
arg_types = ", ".join([arg[1] for arg in call_args])
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
kernel_launch_code += self.generate_tma_descriptor_args(
desc_name_map) + KERNEL_LAUNCH_FUNC_PY.format(
function_name, self._pythonic_expr(grid_info[0]),
self._pythonic_expr(grid_info[1]), self._pythonic_expr(grid_info[2]),
self._pythonic_expr(block_info[0]), self._pythonic_expr(block_info[1]),
self._pythonic_expr(
block_info[2]), smem_str, arg_names, arg_types, device_index)
# Wrap the kernel dispatch logic in an external C function
host_func = PREDEF_HOST_FUNC_PY.format(
repr(list(function_informations.keys())), def_args, kernel_launch_code)
return host_func
def generate_tma_descriptor_args(self, desc_name_map: dict[str, str]) -> str:
tma_descripter_init = ""
if self.tma_descriptor_args is None:
return tma_descripter_init
for handle_name, name in desc_name_map.items():
desc_name = name + "_desc"
assert desc_name in self.tma_descriptor_args, f"TMA descriptor {desc_name} not found in {self.tma_descriptor_args}"
args = self.tma_descriptor_args[desc_name]
# Skip __tvm_tensormap_create_tiled
if len(args) < 3:
raise ValueError(
f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
_, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
tensor_rank = int(tensor_rank)
# Validate tensor_rank
if not isinstance(tensor_rank, int) or tensor_rank <= 0:
raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer")
# Calculate required length for remaining_args
# 4 groups of tensor_rank size + 4 parameters
expected_args_len = 4 * tensor_rank + 4
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
global_dim = [str(i) for i in global_dim]
global_stride = [str(i) for i in global_stride]
box_dim = [str(i) for i in box_dim]
element_strides = [str(i) for i in element_strides]
# Extract remaining parameters
try:
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e
tma_descripter_init += TMA_DESC_INIT_FUNC_PY.format(
handle_name, dtype, tensor_rank, globalAddress,
", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", global_dim)),
", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", global_stride)),
", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", box_dim)),
", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})",
element_strides)), interleave, swizzle, l2Promotion, oobFill)
return tma_descripter_init
def update_lib_code(self, code: str):
# Update the library code with the given code string
self.lib_code = code
# Organize function information for code generation
function_informations = {}
for function_name in self.function_names:
# Do not update function with dispatch host function
if (function_name not in self.block_info) or (function_name not in self.grid_info):
continue
function_informations[function_name] = {
"function_name": function_name,
"block_info": self.block_info[function_name],
"grid_info": self.grid_info[function_name],
"dynamic_smem_buf": self.dynamic_smem_buf[function_name],
}
# Create the host function wrapper for the CUDA kernel
self.host_func = self.create_dispatch_func(code, function_informations)
return self.lib_code
def get_stream_type(self) -> dict[str, str]:
return {"name": "stream=0", "type": "int"}
class TLHIPSourceWrapper(TLCUDASourceWrapper): class TLHIPSourceWrapper(TLCUDASourceWrapper):
""" """
A wrapper class for the TileLang HIP backend. A wrapper class for the TileLang HIP backend.
...@@ -1230,9 +845,10 @@ class TLPyWrapper(TLWrapper): ...@@ -1230,9 +845,10 @@ class TLPyWrapper(TLWrapper):
def wrap(self, c_source: str): def wrap(self, c_source: str):
# assert self.scheduled_ir_module is not None, "Please assign optimized module first." # assert self.scheduled_ir_module is not None, "Please assign optimized module first."
if is_cuda_target(self.target): if is_cuda_target(self.target):
from tilelang.jit.adapter.nvrtc import TLNVRTCSourceWrapper
wrapper_class = TLNVRTCSourceWrapper wrapper_class = TLNVRTCSourceWrapper
else: else:
raise ValueError(f"Unsupported platform: {self.arch.platform}") raise ValueError(f"Unsupported target for NVRTC backend: {self.target}")
wrapper = wrapper_class( wrapper = wrapper_class(
scheduled_ir_module=self.scheduled_ir_module, scheduled_ir_module=self.scheduled_ir_module,
source=c_source, source=c_source,
......
...@@ -15,7 +15,7 @@ from tilelang import tvm ...@@ -15,7 +15,7 @@ from tilelang import tvm
from tilelang import env from tilelang import env
from tilelang.engine.param import CompiledArtifact, KernelParam from tilelang.engine.param import CompiledArtifact, KernelParam
from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
NVRTCKernelAdapter, TorchDLPackKernelAdapter, MetalKernelAdapter) TorchDLPackKernelAdapter, MetalKernelAdapter)
from tilelang.profiler import Profiler, TensorSupplyType from tilelang.profiler import Profiler, TensorSupplyType
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
from tilelang.contrib import nvcc as tl_nvcc from tilelang.contrib import nvcc as tl_nvcc
...@@ -270,6 +270,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -270,6 +270,7 @@ class JITKernel(Generic[_P, _T]):
compile_flags=compile_flags, compile_flags=compile_flags,
) )
elif execution_backend == "nvrtc": elif execution_backend == "nvrtc":
from tilelang.jit.adapter import NVRTCKernelAdapter
adapter = NVRTCKernelAdapter( adapter = NVRTCKernelAdapter(
params=artifact.params, params=artifact.params,
result_idx=out_idx, result_idx=out_idx,
...@@ -339,6 +340,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -339,6 +340,7 @@ class JITKernel(Generic[_P, _T]):
pass_configs=pass_configs, pass_configs=pass_configs,
) )
elif execution_backend == "nvrtc": elif execution_backend == "nvrtc":
from tilelang.jit.adapter import NVRTCKernelAdapter
adapter = NVRTCKernelAdapter.from_database( adapter = NVRTCKernelAdapter.from_database(
params=params, params=params,
result_idx=result_idx, result_idx=result_idx,
......
...@@ -5,6 +5,7 @@ from typing import Callable ...@@ -5,6 +5,7 @@ from typing import Callable
from tilelang.layout import Layout from tilelang.layout import Layout
from tvm.script.parser.tir import attr, block_attr from tvm.script.parser.tir import attr, block_attr
from tvm.tir import FloatImm
__all__ = [ __all__ = [
"use_swizzle", "use_swizzle",
...@@ -49,5 +50,5 @@ def annotate_l2_hit_ratio(l2_hit_ratio_map: dict): ...@@ -49,5 +50,5 @@ def annotate_l2_hit_ratio(l2_hit_ratio_map: dict):
_l2_hit_ratio_map = {} _l2_hit_ratio_map = {}
for buffer, hit_ratio in l2_hit_ratio_map.items(): for buffer, hit_ratio in l2_hit_ratio_map.items():
assert buffer.scope() == "global", "persistent L2 can only be applied to global buffers" assert buffer.scope() == "global", "persistent L2 can only be applied to global buffers"
_l2_hit_ratio_map[buffer.data] = float(hit_ratio) _l2_hit_ratio_map[buffer.data] = FloatImm("float32", float(hit_ratio))
return block_attr({"l2_hit_ratio_map": _l2_hit_ratio_map}) return block_attr({"l2_hit_ratio_map": _l2_hit_ratio_map})