Unverified commit 74da3696, authored by Lei Wang and committed by GitHub

[FFI] Use tvm ffi as the default execution backend (#1259)

* [Refactor] Update FFI type handling and simplify argument management

* Refactored FFI type definitions in runtime and code generation files to use `TVMFFIAny` instead of `TVMValue`, enhancing type clarity.
* Updated function registration in `runtime.cc` to utilize canonical names for better consistency.
* Simplified argument handling in the `simplify` transformation, ensuring unused buffer parameters are removed only when simplification is enabled.
* Adjusted autotuner and profiler parameters to standardize the execution backend to `tvm_ffi`, improving clarity in backend selection (a minimal end-to-end sketch follows this list).
* Removed obsolete `adapt_torch2tvm` function from tensor utilities to streamline the codebase and reduce complexity.
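
A minimal end-to-end sketch of the new default path (assuming a CUDA-capable device; the toy elementwise kernel below is illustrative only and is not part of this PR):

```python
import torch
import tilelang
import tilelang.language as T

M, N, block_size, dtype = 1024, 1024, 256, "float32"

@T.prim_func
def add(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
        C: T.Tensor((M, N), dtype),
):
    with T.Kernel(M * N // block_size, threads=block_size) as bx:
        for i in T.serial(block_size):
            idx = bx * block_size + i
            if idx < M * N:
                C[idx // N, idx % N] = A[idx // N, idx % N] + B[idx // N, idx % N]

# execution_backend now defaults to "auto", which resolves to "tvm_ffi" on CUDA;
# it is spelled out here only for clarity.
kernel = tilelang.compile(add, out_idx=[-1], execution_backend="tvm_ffi")
a = torch.randn(M, N, device="cuda")
b = torch.randn(M, N, device="cuda")
c = kernel(a, b)  # C is returned rather than passed in because out_idx=[-1]
torch.testing.assert_close(c, a + b, atol=1e-5, rtol=1e-5)
print(kernel.get_kernel_source())  # device-side CUDA source
```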

* [Update] Sync TVM submodule and enhance kernel source handling

* Updated the TVM submodule to commit cdc2aced, ensuring compatibility with recent changes.
* Added functionality to print kernel source in `example_blocksparse_gemm.py` for better debugging.
* Commented out the main execution call in test files to prevent unintended execution during testing.
* Introduced `tilelang.disable_cache()` in various test files to streamline testing and avoid cache-related issues.
* Refactored kernel source retrieval methods to improve clarity and consistency across different execution backends.

* [Refactor] Clean up imports and improve code formatting

* Removed unused import of `tilelang.testing` in `test_example_blocksparse_gemm.py` to streamline the code.
* Reformatted several lines in `arg_binder.cc`, `make_packed_api.cc`, `tvm_ffi.py`, and `adapter.py` for improved readability and consistency.
* Updated comments and spacing in `tvm_ffi.py` to enhance clarity without altering functionality.

* Update execution backend options and improve resolution logic

- Changed default execution backend from "cython" to "auto" in multiple locations to allow automatic selection based on the target.
- Expanded the list of supported execution backends to include "torch" and "nvrtc" across various classes and functions.
- Enhanced backend resolution logic in `KernelCache` and `AutoTuner` to ensure appropriate backend selection based on the target (a resolution sketch follows this list).
- Updated documentation to reflect changes in execution backend options and their defaults.
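
A rough sketch of the resolution step, assuming the `tilelang.jit.execution_backend` helpers introduced in this PR expose the two functions used below:

```python
from tvm.target import Target
from tilelang.utils.target import determine_target
from tilelang.jit.execution_backend import (
    allowed_backends_for_target,
    resolve_execution_backend,
)

# Normalize the target first, then map "auto" to a concrete backend:
# cuda -> "tvm_ffi", metal -> "torch", other targets -> "cython".
target = Target(determine_target("auto"))
backend = resolve_execution_backend("auto", target)
print(backend, allowed_backends_for_target(target, include_unavailable=False))
```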

* lint fix

* fix

* Enhance argument handling in CUDA and HIP runtime modules

- Updated `ExtractFuncInfo` in `rt_mod_cuda.cc` and `rt_mod_hip.cc` to map boolean argument types to int32, ensuring compatibility with device runtime.
- Refactored `BindDLTensor` in `arg_binder.cc` to improve null handling and validation checks for DLTensor parameters, utilizing expression-level guards to prevent dereferencing null pointers.
- Enhanced error checking for buffer shape, strides, and data fields, ensuring robust handling of optional inputs and maintaining consistency across various checks.

* lint fix

* lint fix

* lint fix

* lint fix

* minor fix

* fix

* recover check

* Refactor argument binding and validation in `arg_binder.cc`

- Improved null handling and validation checks in `BindDLTensor`, ensuring safe dereferencing of pointers.
- Enhanced consistency checks for buffer shape, strides, and data fields, utilizing expression-level guards.
- Updated `MakePackedAPI` to maintain code clarity and consistency in argument handling.
- Minor adjustments in test files to streamline kernel execution and improve readability.

* lint fix

* stride fix

* minor fix

* fix

* lint fix

* lint fix

* Add CUDA stream access policy window helpers and integrate with L2 persistent cache management

- Introduced functions to set and reset the CUDA stream access policy window, allowing for better control over L2 cache usage.
- Updated runtime files to include new FFI packed functions for managing stream attributes.
- Modified `lower_hopper_intrin` to incorporate prologue and epilogue statements for L2 cache setup and teardown.
- Enhanced tests to verify the inclusion of the new FFI calls in the generated kernel source (a condensed usage sketch follows this list).
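
A condensed sketch adapted from the test added in this PR (assumes a CUDA device; the kernel itself is only a vehicle for the L2 annotation):

```python
import torch
import tilelang
import tilelang.language as T
from tilelang.language import annotate_l2_hit_ratio

M = N = 1024

@tilelang.jit(out_idx=[-1], execution_backend="tvm_ffi")
def elementwise_add_with_l2_cache(M, N, block_size=256, dtype="float32"):

    @T.prim_func
    def kernel(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(M * N // block_size, threads=block_size) as bx:
            # Keep B resident in the L2 persistent window with an 80% hit ratio.
            annotate_l2_hit_ratio({B: 0.8})
            for i in T.serial(block_size):
                idx = bx * block_size + i
                if idx < M * N:
                    C[idx // N, idx % N] = A[idx // N, idx % N] + B[idx // N, idx % N]

    return kernel

kernel = elementwise_add_with_l2_cache(M, N)
# The host wrapper should now set up and tear down the access policy window.
host_source = kernel.get_host_source()
assert "__tvm_cuda_stream_set_access_policy_window_packed" in host_source
assert "__tvm_cuda_stream_reset_access_policy_window_packed" in host_source

a = torch.randn(M, N, device="cuda")
b = torch.randn(M, N, device="cuda")
c = kernel(a, b)
torch.testing.assert_close(c, a + b, atol=1e-5, rtol=1e-5)
```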

* check with symbolic

* support null ptr

* Update CMakeLists and lower.py for code generation and subproject status

- Added `codegen_c_host.cc` to the list of source files in CMakeLists.txt for improved code generation support.
- Updated the function call in `lower.py` to use `target.build.tilelang_c` for C target host code generation, enhancing compatibility (see the dispatch sketch below).
- Marked the TVM subproject as dirty to indicate local modifications.
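
A minimal sketch of the resulting host-codegen dispatch, mirroring the `host_codegen` helper touched by this change:

```python
from tilelang import tvm as tvm
from tvm.target import Target

def host_codegen(host_mod: tvm.IRModule, target_host: Target):
    if target_host.kind.name == "llvm":
        host_mod = tvm.ffi.get_global_func("target.build.llvm")(host_mod, target_host)
    elif target_host.kind.name == "c":
        # C host modules now go through the TileLang-specific C codegen.
        host_mod = tvm.ffi.get_global_func("target.build.tilelang_c")(host_mod, target_host)
    else:
        raise ValueError(f"Target host {target_host.kind.name} is not supported")
    return host_mod
```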

* lint fix

* Update comments for clarity in quickstart.py
parent 921b96a3
@@ -514,5 +514,4 @@ def test_assert_tl_matmul_block_all_dynamic_with_pass_config():
if __name__ == "__main__":
-    # tilelang.testing.main()
-    assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
+    tilelang.testing.main()
@@ -83,28 +83,27 @@ def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_
def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-    func = ptr_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
+    kernel = ptr_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
    a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype))
    b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype))
    c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype))
    d = torch.randn(N, device="cuda", dtype=map_torch_type(accum_dtype))
-    func(a, b, c, None, M, N, K, False)
+    kernel(a, b, c, None, M, N, K, False)
    ref_no_bias = (a @ b.T).to(map_torch_type(accum_dtype))
    ref_with_bias = ref_no_bias + d
    torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)
-    func(a, b, c, d, M, N, K, True)
+    kernel(a, b, c, d, M, N, K, True)
    torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)
-    func = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
+    kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
-    func(a, b, c, None, False)
+    kernel(a, b, c, None, False)
    torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)
-    func(a, b, c, d, True)
+    kernel(a, b, c, d, True)
    torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)
@@ -90,7 +90,7 @@ def run_gemm(
        code = f"// {stramp}\n" + code
        return code
-    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="ctypes")
+    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi")
    kernel_source = matmul_kernel.get_kernel_source()

@@ -134,8 +134,6 @@ def matmu_jit_kernel(
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
-    import tilelang.language as T
    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),

@@ -193,7 +191,7 @@ def run_gemm_jit_kernel(
        num_threads,
    )
-    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="ctypes")
+    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi")
    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)

@@ -235,19 +233,19 @@ def test_gemm_jit_kernel():
    )
-def run_ctypes_kernel_do_bench(M,
+def run_tvm_ffi_kernel_do_bench(M,
                                N,
                                K,
                                trans_A,
                                trans_B,
                                in_dtype,
                                out_dtype,
                                dtypeAccum,
                                block_M,
                                block_N,
                                block_K,
                                num_stages=3,
                                num_threads=128):
    program = matmul(
        M,
        N,

@@ -264,14 +262,14 @@ def run_ctypes_kernel_do_bench(M,
        num_threads,
    )
-    matmul_kernel = tilelang.compile(program, execution_backend="ctypes")
+    matmul_kernel = tilelang.compile(program, execution_backend="tvm_ffi")
    profiler = matmul_kernel.get_profiler()
-    ctypes_latency = profiler.do_bench(func=matmul_kernel)
-    print(f"Ctypes Latency: {ctypes_latency} ms")
-    assert ctypes_latency is not None
+    tvm_ffi_latency = profiler.do_bench(func=matmul_kernel)
+    print(f"tvm_ffi Latency: {tvm_ffi_latency} ms")
+    assert tvm_ffi_latency is not None
    tvm_latency = profiler.do_bench()
    print(f"TVM Latency: {tvm_latency} ms")

@@ -279,24 +277,24 @@ def run_ctypes_kernel_do_bench(M,
    assert tvm_latency is not None
-def test_ctypes_kernel_do_bench():
-    run_ctypes_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128,
+def test_tvm_ffi_kernel_do_bench():
+    run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128,
                                256, 32, 2)
-def run_ctypes_kernel_multi_stream(M,
+def run_tvm_ffi_kernel_multi_stream(M,
                                    N,
                                    K,
                                    trans_A,
                                    trans_B,
                                    in_dtype,
                                    out_dtype,
                                    dtypeAccum,
                                    block_M,
                                    block_N,
                                    block_K,
                                    num_stages=3,
                                    num_threads=128):
    program = matmul(
        M,
        N,

@@ -313,7 +311,7 @@ def run_ctypes_kernel_multi_stream(M,
        num_threads,
    )
-    matmul_kernel = tilelang.compile(program, execution_backend="ctypes")
+    matmul_kernel = tilelang.compile(program, execution_backend="tvm_ffi")
    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)
    tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()

@@ -332,24 +330,24 @@ def run_ctypes_kernel_multi_stream(M,
    matmul_kernel(tensor_a, tensor_b, tensor_c)
-def test_ctypes_kernel_multi_stream():
-    run_ctypes_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16",
+def test_tvm_ffi_kernel_multi_stream():
+    run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16",
                                    128, 256, 32, 2)
-def run_ctypes_dynamic_shape(M,
+def run_tvm_ffi_dynamic_shape(M,
                              N,
                              K,
                              trans_A,
                              trans_B,
                              in_dtype,
                              out_dtype,
                              dtypeAccum,
                              block_M,
                              block_N,
                              block_K,
                              num_stages=3,
                              num_threads=128):
    program = matmul(
        M,
        N,

@@ -366,7 +364,7 @@ def run_ctypes_dynamic_shape(M,
        num_threads,
    )
-    matmul_kernel = tilelang.compile(program, execution_backend="ctypes")
+    matmul_kernel = tilelang.compile(program, execution_backend="tvm_ffi")
    if isinstance(M, T.Var):
        M = 1024
    if isinstance(N, T.Var):

@@ -393,19 +391,199 @@ def run_ctypes_dynamic_shape(M,
        tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
-def test_ctypes_dynamic_shape():
-    run_ctypes_dynamic_shape(
+def test_tvm_ffi_dynamic_shape():
+    run_tvm_ffi_dynamic_shape(
        T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
-    run_ctypes_dynamic_shape(
+    run_tvm_ffi_dynamic_shape(
        T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
        256, 32, 2)
-    run_ctypes_dynamic_shape(
+    run_tvm_ffi_dynamic_shape(
        T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16",
        "float16", 128, 256, 32, 2)
def check_hopper():
if not torch.cuda.is_available():
return False
props = torch.cuda.get_device_properties(0)
compute_capability = props.major, props.minor
return compute_capability == (9, 0)
def convolution_im2col(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
@T.prim_func
def main(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
out_shared = T.alloc_shared((block_M, block_N), dtype)
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
data_shared: tilelang.layout.make_swizzled_layout(data_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
})
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
T.copy(out_local, out_shared)
T.copy(out_shared, out_flat[by * block_M, bx * block_N])
return main
def run_tvm_ffi_im2col_tma_desc(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=256):
"""Test im2col TMA descriptor functionality in tvm_ffi backend."""
program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages,
num_threads)
conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi")
a = torch.randn(N, H, W, C).cuda().half()
b = torch.randn(K, K, C, F).cuda().half()
out_c = conv_kernel(a, b)
# Reference implementation using torch.conv2d
def ref_program(A, B):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
C = torch.conv2d(A, B, stride=S, padding=P, dilation=D)
C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C
return C
ref_c = ref_program(a, b)
tilelang.testing.torch_assert_close(
out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_tvm_ffi_im2col_tma_desc():
"""Test im2col TMA descriptor with tvm_ffi backend."""
if not check_hopper():
import pytest
pytest.skip("Test requires Hopper GPU (compute capability 9.0)")
# Small test case for im2col TMA descriptor
run_tvm_ffi_im2col_tma_desc(
N=4,
C=64,
H=32,
W=32,
F=64,
K=3,
S=1,
D=1,
P=1,
block_M=64,
block_N=128,
block_K=32,
num_stages=3,
num_threads=256)
def test_tvm_ffi_l2_persistent_map():
"""Test L2 persistent cache annotation with elementwise add."""
from tilelang.language import annotate_l2_hit_ratio
M = 1024
N = 1024
@tilelang.jit(out_idx=[-1], execution_backend="tvm_ffi")
def elementwise_add_with_l2_cache(
M,
N,
block_size=256,
dtype="float32",
):
@T.prim_func
def kernel(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(M * N // block_size, threads=block_size) as bx:
# Annotate L2 persistent cache for buffer B
# B will be accessed multiple times and benefit from L2 caching
annotate_l2_hit_ratio({B: 0.8})
for i in T.serial(block_size):
idx = bx * block_size + i
if idx < M * N:
row = idx // N
col = idx % N
C[row, col] = A[row, col] + B[row, col]
return kernel
# Compile the kernel
kernel = elementwise_add_with_l2_cache(M, N)
source = kernel.get_host_source()
assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source"
assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source"
# Create test tensors
a = torch.randn(M, N, dtype=torch.float32).cuda()
b = torch.randn(M, N, dtype=torch.float32).cuda()
# Run kernel with out_idx=[-1], C is returned not passed in
c = kernel(a, b)
# Verify correctness
ref_c = a + b
tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5)
print("L2 persistent map test passed!")
if __name__ == "__main__":
-    # tilelang.testing.main()
-    test_gemm_f16f16f16_nn()
+    tilelang.testing.main()
@@ -113,7 +113,6 @@ def run_alloc_var_with_initializer(
    kernel = tilelang.compile(program, out_idx=[1])
    code = kernel.get_kernel_source()
-    print(code)
    assert f"= {init_value};" in code

@@ -151,8 +150,7 @@ def run_alloc_multi_vars_with_initializer(
    program = alloc_multi_vars_with_initializer(N, block_N, dtype)
    kernel = tilelang.compile(program, out_idx=[1])
-    code = kernel.get_kernel_source()
-    print(code)
+    code = kernel.get_kernel_source(kernel_only=True)
    assert code.count("= 1;") == 1
    assert code.count("= 2;") == 1
@@ -33,7 +33,7 @@ class CompileArgs:
    """Compile arguments for the auto-tuner. Detailed description can be found in `tilelang.jit.compile`.
    Attributes:
        out_idx: List of output tensor indices.
-        execution_backend: Execution backend to use for kernel execution (default: "cython").
+        execution_backend: Execution backend to use for kernel execution (default: "auto").
        target: Compilation target, either as a string or a TVM Target object (default: "auto").
        target_host: Target host for cross-compilation (default: None).
        verbose: Whether to enable verbose output (default: False).

@@ -42,7 +42,7 @@ class CompileArgs:
    """
    out_idx: list[int] | int | None = None
-    execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython"
+    execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto"
    target: Literal['auto', 'cuda', 'hip'] = 'auto'
    target_host: str | Target = None
    verbose: bool = False

@@ -208,7 +208,7 @@ class AutotuneResult:
    target: str | Target = "auto",
    target_host: str | Target = None,
    out_idx: list[int] | int | None = None,
-    execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
+    execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi",
    pass_configs: dict = None,
    func: Callable = None,
    verbose: bool = False,
@@ -139,8 +139,9 @@ class AutoTuner:
    def set_compile_args(self,
                         out_idx: list[int] | int | None = None,
-                        target: Literal['auto', 'cuda', 'hip'] = 'auto',
-                        execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
+                        target: Literal['auto', 'cuda', 'hip', 'metal'] = 'auto',
+                        execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc",
+                                                   "torch"] = "auto",
                         target_host: str | Target = None,
                         verbose: bool = False,
                         pass_configs: dict[str, Any] | None = None):

@@ -157,10 +158,15 @@ class AutoTuner:
        Returns:
            AutoTuner: Self for method chaining.
        """
+        # Normalize target to a concrete TVM Target and resolve execution backend
+        t = Target(determine_target(target))
+        from tilelang.jit.execution_backend import resolve_execution_backend
+        resolved_backend = resolve_execution_backend(execution_backend, t)
        self.compile_args = CompileArgs(
            out_idx=out_idx,
-            target=Target(determine_target(target)),
-            execution_backend=execution_backend,
+            target=t,
+            execution_backend=resolved_backend,
            target_host=target_host,
            verbose=verbose,
            pass_configs=pass_configs)

@@ -591,7 +597,7 @@ class AutoTuner:
            func=best_kernel.prim_func,
            kernel=best_kernel)
-        if self.compile_args.execution_backend in ("dlpack", "torch"):
+        if self.compile_args.execution_backend in ("torch"):
            logger.warning("DLPack backend does not support cache saving to disk.")
        else:
            with self._lock:

@@ -728,8 +734,9 @@ def autotune(  # This is the new public interface
        Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
    target_host : Union[str, Target], optional
        Target host for cross-compilation. Defaults to None.
-    execution_backend : Literal["dlpack", "ctypes", "cython"], optional
-        Backend for kernel execution and argument passing. Defaults to "cython".
+    execution_backend : Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
+        Backend for kernel execution and argument passing. Use "auto" to pick a sensible
+        default per target (cuda->tvm_ffi, metal->torch, others->cython).
    verbose : bool, optional
        Enables verbose logging during compilation. Defaults to False.
    pass_configs : Optional[Dict[str, Any]], optional
@@ -18,7 +18,8 @@ def cached(
    *args,
    target: str | Target = "auto",
    target_host: str | Target = None,
-    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] | None = "cython",
+    execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
+    | None = "auto",
    verbose: bool | None = False,
    pass_configs: dict | None = None,
    compile_flags: list[str] | str | None = None,
@@ -13,14 +13,15 @@ from typing import Callable, Literal
import cloudpickle
from tvm.target import Target
from tvm.tir import PrimFunc
+from tvm.runtime import Executable
from tilelang.engine.param import KernelParam
from tilelang import env
from tilelang.jit import JITKernel
from tilelang import __version__
-KERNEL_PATH = "kernel.cu"
-WRAPPED_KERNEL_PATH = "wrapped_kernel.cu"
+DEVICE_KERNEL_PATH = "device_kernel.cu"
+HOST_KERNEL_PATH = "host_kernel.cu"
+EXECUTABLE_PATH = "executable.so"
KERNEL_LIB_PATH = "kernel_lib.so"
KERNEL_CUBIN_PATH = "kernel.cubin"
KERNEL_PY_PATH = "kernel.py"

@@ -40,7 +41,7 @@ class KernelCache:
    _instance = None  # For implementing singleton pattern
    _lock = threading.Lock()  # For thread safety
    _memory_cache = {}  # In-memory cache dictionary
-    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython"
+    execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi"
    def __new__(cls):
        """
@@ -69,7 +70,7 @@ class KernelCache:
        self,
        func: Callable,
        out_idx: list[int],
-        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
+        execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi",
        args=None,
        target: str | Target = "auto",
        target_host: str | Target = None,

@@ -117,7 +118,8 @@ class KernelCache:
        *args,
        target: str | Target = "auto",
        target_host: str | Target = None,
-        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
+        execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc",
+                                   "torch"] = "auto",
        verbose: bool = False,
        pass_configs: dict = None,
        compile_flags: list[str] | str | None = None,

@@ -135,12 +137,30 @@ class KernelCache:
        Returns:
            JITKernel: The compiled kernel, either freshly compiled or from cache
        """
# Normalize target and resolve execution backend before proceeding
from tilelang.utils.target import determine_target as _determine_target
from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target
norm_target = Target(_determine_target(target)) if isinstance(target, str) else target
requested_backend = execution_backend
execution_backend = resolve_execution_backend(requested_backend, norm_target)
if verbose:
allowed_now = allowed_backends_for_target(norm_target, include_unavailable=False)
# Avoid duplicate logs when caller already resolved explicitly
if requested_backend in (None, "auto") or requested_backend != execution_backend:
self.logger.info(
"Execution backend resolved -> '%s' (requested='%s', target='%s', allowed: %s)",
execution_backend,
requested_backend,
norm_target.kind.name,
", ".join(sorted(allowed_now)),
)
        if not env.is_cache_enabled():
            return JITKernel(
                func,
                out_idx=out_idx,
                execution_backend=execution_backend,
-                target=target,
+                target=norm_target,
                target_host=target_host,
                verbose=verbose,
                pass_configs=pass_configs,

@@ -152,7 +172,7 @@ class KernelCache:
            out_idx=out_idx,
            execution_backend=execution_backend,
            args=args,
-            target=target,
+            target=norm_target,
            target_host=target_host,
            pass_configs=pass_configs,
            compile_flags=compile_flags,

@@ -168,7 +188,7 @@ class KernelCache:
        self.logger.debug(f"Checking disk cache for kernel {func.attrs['global_symbol']}")
        # Then check disk cache
-        kernel = self._load_kernel_from_disk(key, target, target_host, out_idx,
+        kernel = self._load_kernel_from_disk(key, norm_target, target_host, out_idx,
                                             execution_backend, pass_configs, compile_flags,
                                             func, verbose)
        if kernel is not None:
@@ -186,18 +206,15 @@ class KernelCache:
            func,
            out_idx=out_idx,
            execution_backend=execution_backend,
-            target=target,
+            target=norm_target,
            target_host=target_host,
            verbose=verbose,
            pass_configs=pass_configs,
            compile_flags=compile_flags,
        )
-        if execution_backend in ("dlpack", "torch"):
-            self.logger.warning("DLPack or torch backend does not support cache saving to disk.")
-        else:
-            with self._lock:
-                if env.is_cache_enabled():
-                    self._save_kernel_to_disk(key, kernel, func, verbose)
+        with self._lock:
+            if env.is_cache_enabled():
+                self._save_kernel_to_disk(key, kernel, func, verbose)
        # Store in memory cache after compilation
        self._memory_cache[key] = kernel
@@ -239,6 +256,12 @@ class KernelCache:
        # Use atomic POSIX replace, so other processes cannot see a partial write
        os.replace(temp_path, path)
@staticmethod
def _safe_write_executable(executable: Executable, path: str):
temp_path = os.path.join(env.TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}.so")
executable.export_library(temp_path)
os.replace(temp_path, path)
    def _save_kernel_to_disk(self,
                             key: str,
                             kernel: JITKernel,
@@ -265,41 +288,46 @@ class KernelCache:
        # Save kernel source code
        try:
-            kernel_path = os.path.join(cache_path, KERNEL_PATH)
+            device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH)
            if verbose:
-                self.logger.debug(f"Saving kernel source code to file: {kernel_path}")
+                self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}")
            if kernel.kernel_source is not None:
-                KernelCache._safe_write_file(kernel_path, "w",
+                KernelCache._safe_write_file(device_kernel_path, "w",
                                             lambda file: file.write(kernel.kernel_source))
        except Exception as e:
            self.logger.error(f"Error saving kernel source code to disk: {e}")

        # Save wrapped kernel source code
        try:
-            wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
+            host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH)
            if verbose:
-                self.logger.debug(
-                    f"Saving wrapped kernel source code to file: {wrapped_kernel_path}")
-            KernelCache._safe_write_file(
-                wrapped_kernel_path, "w",
-                lambda file: file.write(kernel.adapter.get_kernel_source()))
+                self.logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}")
+            if self.execution_backend == "tvm_ffi":
+                KernelCache._safe_write_file(
+                    host_kernel_path, "w",
+                    lambda file: file.write(kernel.adapter.get_host_source()))
+            else:
+                KernelCache._safe_write_file(
+                    host_kernel_path, "w",
+                    lambda file: file.write(kernel.adapter.get_kernel_source()))
        except Exception as e:
-            self.logger.error(f"Error saving wrapped kernel source code to disk: {e}")
+            self.logger.error(f"Error saving host kernel source code to disk: {e}")

        # Save the kernel library
        try:
            # Save CUBIN or SO file
-            kernel_lib_path = KERNEL_CUBIN_PATH if self.execution_backend == "nvrtc" else KERNEL_LIB_PATH
+            if self.execution_backend == "nvrtc":
+                kernel_lib_path = KERNEL_CUBIN_PATH
+            elif self.execution_backend == "tvm_ffi":
+                kernel_lib_path = EXECUTABLE_PATH
+            else:
+                kernel_lib_path = KERNEL_LIB_PATH
            kernel_lib_path = os.path.join(cache_path, kernel_lib_path)
-            src_lib_path = kernel.adapter.libpath
-            if verbose:
-                self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
-            KernelCache._safe_write_file(
-                kernel_lib_path, "wb",
-                lambda file: file.write(KernelCache._load_binary(src_lib_path)))
            # Save an extra Python file for NVRTC
            if self.execution_backend == "nvrtc":
+                src_lib_path = kernel.adapter.libpath
                kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH)
                src_lib_path = src_lib_path.replace(".cubin", ".py")
                if verbose:

@@ -307,6 +335,19 @@ class KernelCache:
                KernelCache._safe_write_file(
                    kernel_py_path, "wb",
                    lambda file: file.write(KernelCache._load_binary(src_lib_path)))
+            elif self.execution_backend == "tvm_ffi":
+                executable = kernel.adapter.executable
+                if verbose:
+                    self.logger.debug(f"Saving kernel executable to file: {executable}")
+                KernelCache._safe_write_executable(executable, kernel_lib_path)
+            else:
+                src_lib_path = kernel.adapter.libpath
+                if verbose:
+                    self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
+                KernelCache._safe_write_file(
+                    kernel_lib_path, "wb",
+                    lambda file: file.write(KernelCache._load_binary(src_lib_path)))
        except Exception as e:
            self.logger.error(f"Error saving kernel library to disk: {e}")
@@ -326,7 +367,7 @@ class KernelCache:
                               target: str | Target = "auto",
                               target_host: str | Target = None,
                               out_idx: list[int] = None,
-                              execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
+                              execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi",
                               pass_configs: dict = None,
                               compile_flags: list[str] | str | None = None,
                               func: Callable = None,
@@ -349,25 +390,39 @@ class KernelCache:
            JITKernel: The loaded kernel if found, None otherwise.
        """
        cache_path = self._get_cache_path(key)
-        wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
-        kernel_lib_path = os.path.join(
-            cache_path, KERNEL_CUBIN_PATH if self.execution_backend == "nvrtc" else KERNEL_LIB_PATH)
+        device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH)
+        host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH)
+        if self.execution_backend == "nvrtc":
+            kernel_lib_path = KERNEL_CUBIN_PATH
+        elif self.execution_backend == "tvm_ffi":
+            kernel_lib_path = EXECUTABLE_PATH
+        else:
+            kernel_lib_path = KERNEL_LIB_PATH
+        kernel_lib_path = os.path.join(cache_path, kernel_lib_path)
        params_path = os.path.join(cache_path, PARAMS_PATH)
        if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]):
            return None
-        kernel_global_source: str | None = None
+        device_kernel_source: str | None = None
+        host_kernel_source: str | None = None
        kernel_params: list[KernelParam] | None = None
        # Load the kernel source file (optional)
+        try:
+            if verbose:
+                self.logger.debug(f"Loading kernel source code from file: {device_kernel_path}")
+            with open(device_kernel_path) as f:
+                device_kernel_source = f.read()
+        except Exception as e:
+            self.logger.error(f"Error loading kernel source code from disk: {e}")
        try:
            if verbose:
                self.logger.debug(
-                    f"Loading wrapped kernel source code from file: {wrapped_kernel_path}")
-            with open(wrapped_kernel_path) as f:
-                kernel_global_source = f.read()
+                    f"Loading wrapped kernel source code from file: {host_kernel_path}")
+            with open(host_kernel_path) as f:
+                host_kernel_source = f.read()
        except Exception as e:
-            self.logger.error(f"Error loading wrapped kernel source code from disk: {e}")
+            self.logger.error(f"Error loading host kernel source code from disk: {e}")
        # Load kernel parameters
        try:
@@ -378,10 +433,11 @@ class KernelCache:
        except Exception as e:
            self.logger.error(f"Error loading kernel parameters from disk: {e}")
-        if kernel_global_source and kernel_params:
+        if host_kernel_source and device_kernel_source and kernel_params:
            return JITKernel.from_database(
                func=func,
-                kernel_global_source=kernel_global_source,
+                host_kernel_source=host_kernel_source,
+                device_kernel_source=device_kernel_source,
                kernel_lib_path=kernel_lib_path,
                params=kernel_params,
                target=target,

@@ -392,6 +448,7 @@ class KernelCache:
                compile_flags=compile_flags,
            )
        else:
+            # TODO(lei): report what the reason is.
            return None

    def _clear_disk_cache(self):
@@ -59,23 +59,3 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
        return tvm_func(*args)

    return _wrapper

-def to_pytorch_func(tvm_func):
-    """Convert a tvm function into one that accepts PyTorch tensors
-    Parameters
-    ----------
-    tvm_func: Function
-        Built tvm function operating on arrays
-    Returns
-    -------
-    wrapped_func: Function
-        Wrapped tvm function that operates on PyTorch tensors
-    """
-    # pylint: disable=import-outside-toplevel
-    import torch
-    import torch.utils.dlpack
-    return convert_func(tvm_func, torch.Tensor, torch.utils.dlpack.to_dlpack)
@@ -146,7 +146,7 @@ def host_codegen(host_mod: tvm.IRModule, target_host: Target) -> tvm.IRModule:
    if target_host.kind.name == "llvm":
        host_mod = tvm.ffi.get_global_func("target.build.llvm")(host_mod, target_host)
    elif target_host.kind.name == "c":
-        host_mod = tvm.ffi.get_global_func("target.build.c")(host_mod, target_host)
+        host_mod = tvm.ffi.get_global_func("target.build.tilelang_c")(host_mod, target_host)
    else:
        raise ValueError(f"Target host {target_host.kind.name} is not supported")
    return host_mod
@@ -23,7 +23,6 @@ except ImportError:  # Python < 3.10
    from typing_extensions import ParamSpec
from tilelang import tvm as tvm
from tilelang.language.v2 import PrimFunc
-from tilelang.jit.adapter.utils import is_metal_target
from tvm.target import Target
from tilelang.jit.kernel import JITKernel

@@ -46,7 +45,8 @@ _T = TypeVar('_T')
def compile(
    func: PrimFunc[_KP, _T] = None,
    out_idx: list[int] | int | None = None,
-    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
+    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc",
+                               "torch"] = "auto",
    target: str | Target = "auto",
    target_host: str | Target | None = None,
    verbose: bool = False,

@@ -61,8 +61,9 @@ def compile(
        The TileLang TIR function to compile and wrap.
    out_idx : Union[List[int], int], optional
        Index(es) of the output tensors to return (default: None).
-    execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
-        Execution backend to use for kernel execution (default: "cython").
+    execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
+        Execution backend to use for kernel execution. Use "auto" to pick a sensible
+        default per target (cuda->tvm_ffi, metal->torch, others->cython).
    target : Union[str, Target], optional
        Compilation target, either as a string or a TVM Target object (default: "auto").
    target_host : Union[str, Target], optional

@@ -80,8 +81,19 @@ def compile(
    # This path is not a performance critical path, so we can afford to convert the target.
    target = Target(determine_target(target))
-    if is_metal_target(target):
-        assert execution_backend == 'torch', 'Currently metal target only support `tl.jit(execution_backend="torch")`'
+    # Resolve execution backend (handles aliases, auto, validation per target)
+    requested_backend = execution_backend
+    from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target
+    execution_backend = resolve_execution_backend(requested_backend, target)
+    if verbose:
+        allowed_now = allowed_backends_for_target(target, include_unavailable=False)
+        logger.info(
+            "Execution backend resolved -> '%s' (requested='%s', target='%s', allowed: %s)",
+            execution_backend,
+            requested_backend,
+            target.kind.name,
+            ", ".join(sorted(allowed_now)),
+        )

    return cached(
        func=func,

@@ -97,7 +109,8 @@ def compile(
def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
                out_idx: list[int] | int | None = None,
-                execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
+                execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc",
+                                           "torch"] = "auto",
                target: str | Target = "auto",
                target_host: str | Target | None = None,
                verbose: bool = False,

@@ -113,8 +126,9 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
        The TileLang TIR functions to compile and wrap.
    out_idx : Union[List[int], int], optional
        Index(es) of the output tensors to return (default: None).
-    execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
-        Execution backend to use for kernel execution (default: "cython").
+    execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
+        Execution backend to use for kernel execution. Use "auto" to pick a sensible
+        default per target (cuda->tvm_ffi, metal->torch, others->cython).
    target : Union[str, Target], optional
        Compilation target, either as a string or a TVM Target object (default: "auto").
    target_host : Union[str, Target], optional

@@ -165,7 +179,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
class JITImpl(Generic[_P, _KP, _T]):
    func: Callable[_P, _T] | PrimFunc[_KP, _T]
    out_idx: list[int] | int | None
-    execution_backend: Literal["dlpack", "ctypes", "cython"]
+    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
    target: str | Target
    target_host: str | Target
    verbose: bool

@@ -286,7 +300,8 @@ def jit(
    out_idx: Any = None,
    target: str | Target = "auto",
    target_host: str | Target = None,
-    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
+    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc",
+                               "torch"] = "auto",
    verbose: bool = False,
    pass_configs: dict[str, Any] | None = None,
    debug_root_path: str | None = None,

@@ -301,7 +316,8 @@ def jit(  # This is the new public interface
    out_idx: Any = None,
    target: str | Target = "auto",
    target_host: str | Target = None,
-    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
+    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc",
+                               "torch"] = "auto",
    verbose: bool = False,
    pass_configs: dict[str, Any] | None = None,
    debug_root_path: str | None = None,

@@ -322,8 +338,9 @@ def jit(  # This is the new public interface
        Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
    target_host : Union[str, Target], optional
        Target host for cross-compilation. Defaults to None.
-    execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
-        Backend for kernel execution and argument passing. Defaults to "cython".
+    execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
+        Backend for kernel execution and argument passing. Use "auto" to pick a sensible
+        default per target (cuda->tvm_ffi, metal->torch, others->cython).
    verbose : bool, optional
        Enables verbose logging during compilation. Defaults to False.
    pass_configs : Optional[Dict[str, Any]], optional
from .base import BaseKernelAdapter  # noqa: F401
-from .dlpack import TorchDLPackKernelAdapter  # noqa: F401
+from .tvm_ffi import TVMFFIKernelAdapter  # noqa: F401
from .ctypes import CtypesKernelAdapter  # noqa: F401
from .cython import CythonKernelAdapter  # noqa: F401
from .nvrtc import NVRTCKernelAdapter  # noqa: F401
@@ -4,6 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Callable
from tilelang.engine.param import KernelParam
+import torch

class BaseKernelAdapter(ABC):

@@ -46,11 +47,54 @@ class BaseKernelAdapter(ABC):
    def _convert_torch_func(self) -> callable:
        pass
# --- Common helpers to align with PyTorch stream/device semantics ---
@staticmethod
def get_current_stream_functor() -> Callable[[], int]:
"""Return a callable that reads Torch's current CUDA stream pointer.
The returned lambda yields the raw CUDA stream handle of the current
PyTorch stream on the active device. It's a thunk (evaluated at call
time) so that any upstream stream guards are respected. If CUDA is
unavailable, it returns a lambda that yields 0.
"""
if torch.cuda.is_available():
try:
torch.cuda._lazy_init()
current_device = torch._C._cuda_getDevice
get_stream = torch._C._cuda_getCurrentRawStream
return lambda: get_stream(current_device())
except Exception:
# Fallback to Python API if internal handles are unavailable
return lambda: int(torch.cuda.current_stream().cuda_stream)
# CPU or CUDA unavailable: no stream semantics
return lambda: 0
@staticmethod
def get_current_device_functor() -> Callable[[], torch.device]:
"""Return a callable that yields Torch's current device.
Similar to the stream functor, we capture a callable that, when called,
fetches the current device according to PyTorch. On CPU or when CUDA is
unavailable, returns ``torch.device('cpu')``.
"""
if torch.cuda.is_available():
try:
torch.cuda._lazy_init()
current_device = torch._C._cuda_getDevice
return lambda: torch.device("cuda", current_device())
except Exception:
return lambda: torch.device("cuda", torch.cuda.current_device())
# CPU fallback
return lambda: torch.device("cpu")
    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self.func(*args, **kwds)

-    def get_kernel_source(self) -> str:
-        return self.mod.imported_modules[0].get_source()
+    def get_kernel_source(self, kernel_only: bool = True) -> str:
+        if kernel_only:
+            return self.mod.imports[0].inspect_source()
+        else:
+            return self.mod.inspect_source() + "\n\n" + self.mod.imports[0].inspect_source()

    def _post_init(self):
        self.func = self._convert_torch_func()
@@ -14,6 +14,7 @@ from tilelang.utils.target import determine_target
from tilelang.utils.language import retrieve_func_from_module

+# TODO(lei): remove ctypes adapter.
class CtypesKernelAdapter(BaseKernelAdapter):
    """Adapter class that converts TVM/TIR functions to callable CUDA kernels using ctypes.

@@ -28,9 +29,9 @@ class CtypesKernelAdapter(BaseKernelAdapter):
    ir_module: tvm.IRModule | None = None
    # The global source code of the kernel -> global means the source code of the kernel
    # that is not wrapped by the wrapper code
-    kernel_global_source: str | None = None
+    host_kernel_source: str | None = None
+    device_kernel_source: str | None = None
    lib: ctypes.CDLL | None = None  # Compiled library handle
-    wrapped_source: str | None = None  # Generated C++ wrapper code
    # Maps symbolic variables to their corresponding buffer and shape indices
    dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None
    # Pass configs for the compiler

@@ -47,7 +48,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
                 func_or_mod: tir.PrimFunc | tvm.IRModule,
                 host_mod: tvm.IRModule | None = None,
                 device_mod: tvm.IRModule | None = None,
-                 kernel_global_source: str | None = None,
+                 host_kernel_source: str | None = None,
+                 device_kernel_source: str | None = None,
                 verbose: bool = False,
                 pass_configs: dict[str, Any] | None = None,
                 compile_flags: list[str] | None = None):

@@ -62,7 +64,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
        """
        self.params = params
        self.result_idx = self._legalize_result_idx(result_idx)
-        self.kernel_global_source = kernel_global_source
+        self.host_kernel_source = host_kernel_source
+        self.device_kernel_source = device_kernel_source
        if isinstance(func_or_mod, tir.PrimFunc):
            self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})

@@ -111,7 +114,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
                      result_idx: list[int],
                      target: str,
                      func_or_mod: tir.PrimFunc | tvm.IRModule,
-                      kernel_global_source: str,
+                      host_kernel_source: str,
+                      device_kernel_source: str,
                      kernel_lib_path: str,
                      verbose: bool = False,
                      pass_configs: dict[str, Any] | None = None,

@@ -119,8 +123,9 @@ class CtypesKernelAdapter(BaseKernelAdapter):
        adapter = cls.__new__(cls)
        adapter.params = params
        adapter.result_idx = adapter._legalize_result_idx(result_idx)
-        adapter.kernel_global_source = kernel_global_source
-        adapter.wrapped_source = kernel_global_source
+        adapter.host_kernel_source = host_kernel_source
+        adapter.device_kernel_source = device_kernel_source
+        adapter.wrapped_source = device_kernel_source + "\n\n" + host_kernel_source
        adapter.pass_configs = pass_configs

        if isinstance(func_or_mod, tir.PrimFunc):

@@ -288,7 +293,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
    def get_kernel_source(self, kernel_only: bool = False):
        """Returns the source code of the compiled kernel."""
        if kernel_only:
-            return self.kernel_global_source
+            return self.device_kernel_source
        else:
-            assert self.wrapped_source is not None, "Wrapped source is not available"
-            return self.wrapped_source
+            # Wrapper only has host kernel source
+            return self.host_kernel_source
...@@ -48,9 +48,9 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -48,9 +48,9 @@ class CythonKernelAdapter(BaseKernelAdapter):
ir_module: tvm.IRModule | None = None ir_module: tvm.IRModule | None = None
# The global source code of the kernel -> global means the source code of the kernel # The global source code of the kernel -> global means the source code of the kernel
# that is not wrapped by the wrapper code # that is not wrapped by the wrapper code
kernel_global_source: str | None = None host_kernel_source: str | None = None
device_kernel_source: str | None = None
lib: ctypes.CDLL | None = None # Compiled library handle lib: ctypes.CDLL | None = None # Compiled library handle
wrapped_source: str | None = None # Generated C++ wrapper code
# Maps symbolic variables to their corresponding buffer and shape indices # Maps symbolic variables to their corresponding buffer and shape indices
dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None
# Maps pointer arguments to their corresponding (buffer_index, shape_dimension) # Maps pointer arguments to their corresponding (buffer_index, shape_dimension)
...@@ -77,7 +77,7 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -77,7 +77,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
func_or_mod: tir.PrimFunc | tvm.IRModule, func_or_mod: tir.PrimFunc | tvm.IRModule,
host_mod: tvm.IRModule | None = None, host_mod: tvm.IRModule | None = None,
device_mod: tvm.IRModule | None = None, device_mod: tvm.IRModule | None = None,
kernel_global_source: str | None = None, device_kernel_source: str | None = None,
verbose: bool = False, verbose: bool = False,
pass_configs: dict[str, Any] | None = None, pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None): compile_flags: list[str] | None = None):
...@@ -92,7 +92,7 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -92,7 +92,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
""" """
self.params = params self.params = params
self.result_idx = self._legalize_result_idx(result_idx) self.result_idx = self._legalize_result_idx(result_idx)
self.kernel_global_source = kernel_global_source self.device_kernel_source = device_kernel_source
if isinstance(func_or_mod, tir.PrimFunc): if isinstance(func_or_mod, tir.PrimFunc):
self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
...@@ -121,9 +121,9 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -121,9 +121,9 @@ class CythonKernelAdapter(BaseKernelAdapter):
self.wrapper.assign_pass_configs(pass_configs) self.wrapper.assign_pass_configs(pass_configs)
self.wrapper.assign_host_module(host_mod) self.wrapper.assign_host_module(host_mod)
self.wrapper.assign_device_module(device_mod) self.wrapper.assign_device_module(device_mod)
self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True)) self.host_kernel_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
self.lib_generator.update_lib_code(self.wrapped_source) self.lib_generator.update_lib_code(self.host_kernel_source)
self.lib_generator.compile_lib() self.lib_generator.compile_lib()
self.lib = self.lib_generator.load_lib() self.lib = self.lib_generator.load_lib()
...@@ -150,7 +150,8 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -150,7 +150,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
result_idx: list[int], result_idx: list[int],
target: str, target: str,
func_or_mod: tir.PrimFunc | tvm.IRModule, func_or_mod: tir.PrimFunc | tvm.IRModule,
kernel_global_source: str, host_kernel_source: str,
device_kernel_source: str,
kernel_lib_path: str, kernel_lib_path: str,
verbose: bool = False, verbose: bool = False,
pass_configs: dict[str, Any] | None = None, pass_configs: dict[str, Any] | None = None,
...@@ -158,8 +159,8 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -158,8 +159,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter = cls.__new__(cls) adapter = cls.__new__(cls)
adapter.params = params adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.result_idx = adapter._legalize_result_idx(result_idx)
adapter.kernel_global_source = kernel_global_source adapter.host_kernel_source = host_kernel_source
adapter.wrapped_source = kernel_global_source adapter.device_kernel_source = device_kernel_source
adapter.pass_configs = pass_configs adapter.pass_configs = pass_configs
if isinstance(func_or_mod, tir.PrimFunc): if isinstance(func_or_mod, tir.PrimFunc):
...@@ -382,7 +383,8 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -382,7 +383,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
def get_kernel_source(self, kernel_only: bool = False): def get_kernel_source(self, kernel_only: bool = False):
"""Returns the source code of the compiled kernel.""" """Returns the source code of the compiled kernel."""
if kernel_only: if kernel_only:
return self.kernel_global_source return self.device_kernel_source
else: else:
assert self.wrapped_source is not None, "Wrapped source is not available" # Wrapper only has host kernel source
return self.wrapped_source assert self.host_kernel_source is not None, "Wrapped source is not available"
return self.host_kernel_source
"""The profiler and convert to torch utils"""
import torch
from tilelang.contrib.dlpack import to_pytorch_func
from .base import BaseKernelAdapter
class TorchDLPackKernelAdapter(BaseKernelAdapter):
def _convert_torch_func(self) -> callable:
torch_func = to_pytorch_func(self.mod)
def func(*ins: list[torch.Tensor]):
if len(ins) + len(self.result_idx) != len(self.params):
raise ValueError(
f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs"
)
ins_idx = 0
args = []
# use the device of the first input tensor if available
device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
for i in range(len(self.params)):
if i in self.result_idx:
dtype = self.params[i].dtype
shape = list(map(int, self.params[i].shape))
tensor = torch.empty(*shape, dtype=dtype, device=device)
else:
tensor = ins[ins_idx]
ins_idx += 1
args.append(tensor)
torch_func(*args)
if len(self.result_idx) == 1:
return args[self.result_idx[0]]
else:
return [args[i] for i in self.result_idx]
return func
...@@ -34,7 +34,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -34,7 +34,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
func_or_mod: tir.PrimFunc | tvm.IRModule, func_or_mod: tir.PrimFunc | tvm.IRModule,
host_mod: tvm.IRModule | None = None, host_mod: tvm.IRModule | None = None,
device_mod: tvm.IRModule | None = None, device_mod: tvm.IRModule | None = None,
kernel_global_source: str | None = None, device_kernel_source: str | None = None,
verbose: bool = False, verbose: bool = False,
pass_configs: dict[str, Any] | None = None, pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None): compile_flags: list[str] | None = None):
...@@ -43,7 +43,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -43,7 +43,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
self.params = params self.params = params
self.result_idx = self._legalize_result_idx(result_idx) self.result_idx = self._legalize_result_idx(result_idx)
self.kernel_global_source = kernel_global_source self.device_kernel_source = device_kernel_source
if isinstance(func_or_mod, tir.PrimFunc): if isinstance(func_or_mod, tir.PrimFunc):
self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
...@@ -74,10 +74,10 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -74,10 +74,10 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
self.wrapper.assign_pass_configs(pass_configs) self.wrapper.assign_pass_configs(pass_configs)
self.wrapper.assign_host_module(host_mod) self.wrapper.assign_host_module(host_mod)
self.wrapper.assign_device_module(device_mod) self.wrapper.assign_device_module(device_mod)
self.host_func, self.function_names = self.wrapper.wrap(kernel_global_source) self.host_func, self.function_names = self.wrapper.wrap(device_kernel_source)
self.lib_generator = NVRTCLibraryGenerator(self.target, self.verbose) self.lib_generator = NVRTCLibraryGenerator(self.target, self.verbose)
self.lib_generator.update_lib_code(self.kernel_global_source) self.lib_generator.update_lib_code(self.device_kernel_source)
self.lib_generator.update_host_func(self.host_func) self.lib_generator.update_host_func(self.host_func)
self.lib_generator.assign_compile_flags(compile_flags) self.lib_generator.assign_compile_flags(compile_flags)
self.lib_generator.compile_lib() self.lib_generator.compile_lib()
...@@ -97,7 +97,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -97,7 +97,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
result_idx: list[int], result_idx: list[int],
target: str, target: str,
func_or_mod: tir.PrimFunc | tvm.IRModule, func_or_mod: tir.PrimFunc | tvm.IRModule,
kernel_global_source: str, host_kernel_source: str,
device_kernel_source: str,
kernel_lib_path: str, kernel_lib_path: str,
verbose: bool = False, verbose: bool = False,
pass_configs: dict[str, Any] | None = None, pass_configs: dict[str, Any] | None = None,
...@@ -105,7 +106,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -105,7 +106,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
adapter = cls.__new__(cls) adapter = cls.__new__(cls)
adapter.params = params adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.result_idx = adapter._legalize_result_idx(result_idx)
adapter.kernel_global_source = kernel_global_source adapter.host_kernel_source = host_kernel_source
adapter.device_kernel_source = device_kernel_source
if isinstance(func_or_mod, tir.PrimFunc): if isinstance(func_or_mod, tir.PrimFunc):
adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
...@@ -167,7 +169,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -167,7 +169,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
dynamic_symbolic_map[shape] = (i, j) dynamic_symbolic_map[shape] = (i, j)
return dynamic_symbolic_map return dynamic_symbolic_map
def get_kernel_source(self) -> str | None: def get_kernel_source(self, kernel_only: bool = True) -> str | None:
"""Get the CUDA kernel source code. """Get the CUDA kernel source code.
Returns Returns
...@@ -175,7 +177,10 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -175,7 +177,10 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
Optional[str] Optional[str]
The kernel source code, or None if not available The kernel source code, or None if not available
""" """
return self.kernel_global_source if kernel_only:
return self.device_kernel_source
else:
return self.host_func
def _forward_from_prebuild_lib(self, *args, stream: int | None = None): def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
"""Low-level function to call the compiled CUDA kernel. """Low-level function to call the compiled CUDA kernel.
......
"""Utilities to adapt TVM FFI kernels to Torch tensors.
This adapter intentionally captures PyTorch's current CUDA stream and device
via light-weight callables so that, when the wrapped function is invoked,
the execution observes the same stream context as the active Torch code.
On non-CUDA builds, the stream/device fall back to 0/CPU semantics.
"""
from __future__ import annotations
from typing import Callable, Any
import torch
from tilelang import tvm
from tvm import runtime, tir
from tvm.target import Target
from tvm.relax import TensorType
from tilelang.utils.target import determine_target
from tilelang.jit.adapter.base import BaseKernelAdapter
from tilelang.utils.language import retrieve_func_from_module
from tilelang.engine.param import KernelParam
class TVMFFIKernelAdapter(BaseKernelAdapter):
"""Adapter that runs a TVM runtime.Executable with Torch tensors.
Notes
- We capture the "current" PyTorch CUDA stream/device as thunks (callables)
rather than materializing them at construction time. This ensures the
actual stream/device is read just-in-time when the function runs, matching
the user's current Torch context (e.g., after a stream guard/switch).
- The stream pointer returned is a raw CUDA stream handle compatible with
TVM's device API; on CPU or when CUDA is unavailable, we return 0.
"""
# Class attributes to store compiled kernel information
target: str | Target = "cuda"
ir_module: tvm.IRModule | None = None
# The raw ("global") kernel source, i.e. the device source code
# before it is wrapped by the host wrapper code
host_kernel_source: str | None = None
device_kernel_source: str | None = None
executable: tvm.runtime.Executable | None = None
# Pass configs for the compiler
pass_configs: dict[str, Any] | None = None
# host_mod
host_mod: tvm.IRModule | None = None
# device_mod
device_mod: tvm.IRModule | None = None
# rt_mod
rt_mod: tvm.runtime.Module | None = None
# Maps symbolic variables to their corresponding buffer and shape indices
dynamic_symbolic_map: dict[tir.Var, tuple[int, int, int]] | None = None
# Stream/device functors are inherited from BaseKernelAdapter
def __init__(self,
params: list[KernelParam],
result_idx: list[int],
target: str | Target,
func_or_mod: tir.PrimFunc | tvm.IRModule,
host_mod: tvm.IRModule | None = None,
device_mod: tvm.IRModule | None = None,
rt_mod: tvm.runtime.Module | None = None,
host_kernel_source: str | None = None,
device_kernel_source: str | None = None,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
"""Initialize the adapter with the given TIR function or module.
Args:
params: List of tensor types for inputs/outputs
result_idx: Indices of output tensors
target: Target platform (e.g., 'cuda')
func_or_mod: TIR function or module to be compiled
verbose: Enable verbose logging
"""
self.params = params
self.result_idx = self._legalize_result_idx(result_idx)
self.host_kernel_source = host_kernel_source
self.device_kernel_source = device_kernel_source
if isinstance(func_or_mod, tir.PrimFunc):
self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
else:
self.ir_module = func_or_mod
self.target = Target.canon_target(determine_target(target))
self.host_mod = host_mod
self.device_mod = device_mod
self.rt_mod = rt_mod
self.verbose = verbose
self.pass_configs = pass_configs
self.compile_flags = compile_flags
self.dynamic_symbolic_map = self._process_dynamic_symbolic()
self._post_init()
def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int]]:
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (id, buffer_index, dimension)
for runtime shape resolution.
id indicates whether the entry refers to a shape (0) or a stride (1)
"""
func = self.prim_func
params = func.params
buffer_map = func.buffer_map
dynamic_symbolic_map = {}
for i, param in enumerate(params):
if param in buffer_map:
buffer = buffer_map[param]
for j, shape in enumerate(buffer.shape):
if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
(shape not in params)):
dynamic_symbolic_map[shape] = (0, i, j)
for i, param in enumerate(params):
if param in buffer_map:
buffer = buffer_map[param]
for j, stride in enumerate(buffer.strides):
if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and
(stride not in params)):
dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map
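# Illustration (hypothetical PrimFunc): for a buffer argument A at position 0
# with shape (m, 1024) and strides (s0, 1), where m and s0 are tir.Var, the
# returned map would be {m: (0, 0, 0), s0: (1, 0, 0)} -- i.e. m is shape dim 0
# of argument 0, and s0 is stride dim 0 of argument 0.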
def _convert_torch_func(self) -> Callable[..., Any]:
# Capture thunks that reflect Torch's current stream and device.
# These are evaluated at call time to align TVM execution with the
# caller's active PyTorch stream/device.
# current_stream_functor = self.get_current_stream_functor()
current_device_functor = self.get_current_device_functor()
# Convert TVM types to native Python types during initialization
param_dtypes = [param.dtype for param in self.params]
# Convert TVM shape arrays to native Python lists
param_shapes = []
for param in self.params:
native_shape = []
for dim in param.shape:
if isinstance(dim, tir.IntImm):
native_shape.append(int(dim))
elif isinstance(dim, tir.Var):
native_shape.append(dim) # Keep tir.Var for dynamic dimensions
else:
native_shape.append(dim)
param_shapes.append(native_shape)
if self.executable is None:
self.executable = runtime.Executable(self.rt_mod)
dynamic_symbolic_map = self._process_dynamic_symbolic()
executable = self.executable
# Prepare helpers for friendly dtype error messages
prim_func = self.prim_func
buffer_map = prim_func.buffer_map
params = prim_func.params
# Expected dtype string per parameter index (for buffers only)
expected_dtype_strs: list[str | None] = []
# Track whether each param is a buffer (has dtype) vs scalar
is_buffer_param: list[bool] = []
for p in params:
if p in buffer_map:
expected_dtype_strs.append(str(buffer_map[p].dtype))
is_buffer_param.append(True)
else:
expected_dtype_strs.append(None)
is_buffer_param.append(False)
# Global function name used in error messages
global_symbol = str(prim_func.attrs.get("global_symbol", "main"))
# Map torch dtype to TVM-style dtype string
def torch_dtype_to_tvm_str(dtype: torch.dtype) -> str:
try:
import torch as _torch
except Exception: # pragma: no cover
# Fallback, though torch should always be available here
return str(dtype)
fp8_e4m3fn = getattr(_torch, "float8_e4m3fn", None)
fp8_e4m3fnuz = getattr(_torch, "float8_e4m3fnuz", None)
fp8_e5m2 = getattr(_torch, "float8_e5m2", None)
fp8_e5m2fnuz = getattr(_torch, "float8_e5m2fnuz", None)
if fp8_e4m3fn is not None and dtype == fp8_e4m3fn:
return "float8_e4m3"
if fp8_e4m3fnuz is not None and dtype == fp8_e4m3fnuz:
return "float8_e4m3fnuz"
if fp8_e5m2 is not None and dtype == fp8_e5m2:
return "float8_e5m2"
if fp8_e5m2fnuz is not None and dtype == fp8_e5m2fnuz:
return "float8_e5m2"
# Strip torch. prefix for readability
s = str(dtype)
return s[6:] if s.startswith("torch.") else s
def func(*inputs: torch.Tensor | Any):
# Validate input count strictly
expected_inputs = len(self.params) - len(self.result_idx)
if len(inputs) != expected_inputs:
raise ValueError(
f"Expected {expected_inputs} inputs, got {len(inputs)} (params={len(self.params)}, outputs={len(self.result_idx)})"
)
# Resolve the device used for outputs. Prefer the first tensor input's device
# if available, otherwise use PyTorch's current device.
out_device: torch.device | None = None
# Stitch the full positional argument list expected by the TVM executable
ins_idx: int = 0
tensor_list: list[torch.Tensor] = []
# Prepare input and output tensors
for i in range(len(self.params)):
if i in self.result_idx:
dtype = param_dtypes[i]
shape = []
# Now working with native Python list, no FFI calls needed
for s in param_shapes[i]:
if isinstance(s, tir.Var):
for key in dynamic_symbolic_map:
if (str(s) == str(key)):
ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[
key]
shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
else: # Already converted to Python int during initialization
shape.append(s)
if out_device is None:
out_device = current_device_functor()
if len(shape) == 0:
param_name = self.params[i].name if hasattr(self.params[i],
'name') else f'parameter_{i}'
raise ValueError(
f"Cannot create output tensor (name={param_name}) - 0-dimensional tensors are not supported. "
f"Expected shape: {shape}")
tensor = torch.empty(*shape, dtype=dtype, device=out_device)
else:
tensor = inputs[ins_idx]
# Input dtype validation with clear error message
if is_buffer_param[i]:
expected_dtype_str = expected_dtype_strs[i]
expected_torch_dtype = param_dtypes[i]
# Only check when the argument is a tensor and expected dtype is known
if isinstance(
tensor, torch.Tensor
) and expected_dtype_str is not None and tensor.dtype != expected_torch_dtype:
param_var = params[i]
# Reconstruct TVM-like handle name A_handle for error clarity
handle_name = f"{param_var.name}_handle"
actual_dtype_str = torch_dtype_to_tvm_str(tensor.dtype)
raise RuntimeError(
f"{global_symbol}.{handle_name}.dtype is expected to be {expected_dtype_str}, but got {actual_dtype_str}"
)
ins_idx += 1
tensor_list.append(tensor)
executable(*tensor_list)
# Return outputs in the requested form
if len(self.result_idx) == 1:
return tensor_list[self.result_idx[0]]
return [tensor_list[i] for i in self.result_idx]
return func
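# Usage sketch (hypothetical kernel with buffer params (A, B, C) and
# result_idx=[2]): calling the converted function allocates C with torch.empty
# on the current device, runs the executable, and returns it.
#     fn = adapter._convert_torch_func()
#     C = fn(A, B)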
@classmethod
def from_database(cls,
params: list[TensorType],
result_idx: list[int],
target: str,
func_or_mod: tir.PrimFunc | tvm.IRModule,
host_kernel_source: str,
device_kernel_source: str,
kernel_lib_path: str,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
adapter.host_kernel_source = host_kernel_source
adapter.device_kernel_source = device_kernel_source
adapter.wrapped_source = device_kernel_source + "\n\n" + host_kernel_source
adapter.pass_configs = pass_configs
if isinstance(func_or_mod, tir.PrimFunc):
adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
else:
adapter.ir_module = func_or_mod
target = determine_target(target, return_object=True)
adapter.target = Target.canon_target(determine_target(target))
adapter.verbose = verbose
adapter.executable = runtime.load_module(kernel_lib_path)
adapter._post_init()
return adapter
def get_host_source(self):
"""Returns the source code of the host module."""
if self.host_kernel_source is not None:
return self.host_kernel_source
return self.rt_mod.inspect_source()
def get_device_source(self):
"""Returns the source code of the device module."""
if self.device_kernel_source is not None:
return self.device_kernel_source
return self.rt_mod.imports[0].inspect_source()
def get_kernel_source(self, kernel_only: bool = False):
"""Returns the source code of the compiled kernel."""
if kernel_only:
return self.get_device_source()
else:
return self.get_device_source() + "\n\n" + self.get_host_source()
@property
def prim_func(self) -> tir.PrimFunc:
"""Returns the primary TIR function from the IR module."""
return retrieve_func_from_module(self.ir_module)
from __future__ import annotations
from collections.abc import Iterable
from tvm.target import Target
# Canonical names for execution backends used internally
_CANONICAL_MAP = {
"dlpack": "tvm_ffi", # historical alias
}
def _canon_backend(name: str | None) -> str | None:
if name is None:
return None
key = str(name).lower()
return _CANONICAL_MAP.get(key, key)
def _target_kind(target: Target) -> str:
# tvm.target.Target always has kind.name
return target.kind.name
def allowed_backends_for_target(target: Target, *, include_unavailable: bool = True) -> list[str]:
"""Return allowed execution backends for a given TVM target kind.
include_unavailable: if False, this will filter out backends that are known
to be unavailable at runtime (e.g., NVRTC without cuda-python installed).
"""
kind = _target_kind(target)
if kind == "cuda":
allowed = ["tvm_ffi", "nvrtc", "cython", "ctypes"]
elif kind == "hip":
allowed = ["tvm_ffi", "cython", "ctypes"]
elif kind == "metal":
allowed = ["torch"]
elif kind == "c": # CPU C backend
allowed = ["cython", "ctypes", "tvm_ffi"]
else:
# Fallback: prefer portable hosts
allowed = ["cython", "ctypes", "tvm_ffi"]
if not include_unavailable:
# Drop NVRTC if not importable
try:
from tilelang.jit.adapter.nvrtc import is_nvrtc_available # lazy
if not is_nvrtc_available and "nvrtc" in allowed:
allowed = [b for b in allowed if b != "nvrtc"]
except Exception:
# Be conservative and keep nvrtc if detection itself fails
pass
return allowed
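# For example (assuming a standard TVM build):
#     allowed_backends_for_target(Target("cuda"))  -> ["tvm_ffi", "nvrtc", "cython", "ctypes"]
#     allowed_backends_for_target(Target("metal")) -> ["torch"]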
def _format_options(options: Iterable[str]) -> str:
return ", ".join(sorted(options))
def resolve_execution_backend(requested: str | None, target: Target) -> str:
"""Resolve an execution backend string to a concrete backend.
- Supports the alias "dlpack" -> "tvm_ffi".
- Supports the sentinel "auto" which selects a sensible default per target.
- Validates the combination (target, backend) and raises with helpful
alternatives when invalid.
"""
req = _canon_backend(requested)
allowed_all = allowed_backends_for_target(target, include_unavailable=True)
allowed_avail = allowed_backends_for_target(target, include_unavailable=False)
# Default selection for auto/None
if req in (None, "auto"):
kind = _target_kind(target)
if kind == "cuda":
choice = "tvm_ffi"
elif kind == "metal":
choice = "torch"
else:
choice = "cython"
# If the chosen default is not available (very rare), fall back to first available
if choice not in allowed_avail and allowed_avail:
choice = allowed_avail[0]
return choice
# Validate against allowed
if req not in allowed_all:
raise ValueError(
f"Invalid execution backend '{requested}' for target '{_target_kind(target)}'. "
f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'.")
# Promote to availability-aware set for nicer errors (e.g., nvrtc not installed)
if req not in allowed_avail:
raise ValueError(
f"Execution backend '{requested}' requires extra dependencies and is not available now. "
f"Try one of: {_format_options(allowed_avail)}.")
return req
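# Illustrative resolution behavior (assuming a CUDA target):
#     cuda = Target("cuda")
#     resolve_execution_backend("auto", cuda)    # -> "tvm_ffi"
#     resolve_execution_backend("dlpack", cuda)  # alias     -> "tvm_ffi"
#     resolve_execution_backend("torch", cuda)   # raises ValueError listing the allowed backends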
...@@ -15,7 +15,7 @@ from tilelang import tvm ...@@ -15,7 +15,7 @@ from tilelang import tvm
from tilelang import env from tilelang import env
from tilelang.engine.param import CompiledArtifact, KernelParam from tilelang.engine.param import CompiledArtifact, KernelParam
from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
TorchDLPackKernelAdapter, MetalKernelAdapter) TVMFFIKernelAdapter, MetalKernelAdapter)
from tilelang.profiler import Profiler, TensorSupplyType from tilelang.profiler import Profiler, TensorSupplyType
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
from tilelang.contrib import nvcc as tl_nvcc from tilelang.contrib import nvcc as tl_nvcc
...@@ -55,7 +55,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -55,7 +55,7 @@ class JITKernel(Generic[_P, _T]):
self, self,
func: PrimFunc = None, func: PrimFunc = None,
out_idx: list[int] | int = None, out_idx: list[int] | int = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi",
target: str | Target = "auto", target: str | Target = "auto",
target_host: str | Target = None, target_host: str | Target = None,
verbose: bool = False, verbose: bool = False,
...@@ -72,8 +72,8 @@ class JITKernel(Generic[_P, _T]): ...@@ -72,8 +72,8 @@ class JITKernel(Generic[_P, _T]):
The TileLang TIR function to compile and wrap. The TileLang TIR function to compile and wrap.
out_idx : Union[List[int], int], optional out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None). Index(es) of the output tensors to return (default: None).
execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional execution_backend : Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
Execution backend to use for kernel execution (default: "cython"). Execution backend to use for kernel execution.
target : Union[str, Target], optional target : Union[str, Target], optional
Compilation target, either as a string or a TVM Target object (default: "auto"). Compilation target, either as a string or a TVM Target object (default: "auto").
target_host : Union[str, Target], optional target_host : Union[str, Target], optional
...@@ -102,7 +102,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -102,7 +102,7 @@ class JITKernel(Generic[_P, _T]):
# Validate the execution backend. # Validate the execution backend.
assert execution_backend in [ assert execution_backend in [
"dlpack", "tvm_ffi",
"ctypes", "ctypes",
"cython", "cython",
"nvrtc", "nvrtc",
...@@ -143,13 +143,14 @@ class JITKernel(Generic[_P, _T]): ...@@ -143,13 +143,14 @@ class JITKernel(Generic[_P, _T]):
def from_database( def from_database(
cls, cls,
func: PrimFunc, func: PrimFunc,
kernel_global_source: str, host_kernel_source: str,
device_kernel_source: str,
kernel_lib_path: str, kernel_lib_path: str,
params: list[KernelParam], params: list[KernelParam],
target: str | Target, target: str | Target,
target_host: str | Target, target_host: str | Target,
out_idx: list[int] | int, out_idx: list[int] | int,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"], execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"],
pass_configs: dict[str, Any] | None = None, pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None, compile_flags: list[str] | None = None,
): ):
...@@ -172,7 +173,8 @@ class JITKernel(Generic[_P, _T]): ...@@ -172,7 +173,8 @@ class JITKernel(Generic[_P, _T]):
params=params, params=params,
result_idx=out_idx, result_idx=out_idx,
target=target, target=target,
kernel_global_source=kernel_global_source, host_kernel_source=host_kernel_source,
device_kernel_source=device_kernel_source,
kernel_lib_path=kernel_lib_path, kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags, compile_flags=compile_flags,
...@@ -223,8 +225,8 @@ class JITKernel(Generic[_P, _T]): ...@@ -223,8 +225,8 @@ class JITKernel(Generic[_P, _T]):
compile_flags = self.compile_flags compile_flags = self.compile_flags
# Compile the function with TVM, optimizing with shared memory lowering. # Compile the function with TVM, optimizing with shared memory lowering.
enable_host_codegen = execution_backend == "dlpack" enable_host_codegen = execution_backend == "tvm_ffi"
enable_device_compile = execution_backend == "dlpack" enable_device_compile = execution_backend == "tvm_ffi"
with tvm.transform.PassContext(opt_level=3, config=pass_configs), self.target: with tvm.transform.PassContext(opt_level=3, config=pass_configs), self.target:
artifact = tilelang.lower( artifact = tilelang.lower(
tilelang_func, tilelang_func,
...@@ -236,13 +238,23 @@ class JITKernel(Generic[_P, _T]): ...@@ -236,13 +238,23 @@ class JITKernel(Generic[_P, _T]):
self.artifact = artifact self.artifact = artifact
# Create an adapter based on the specified execution backend. # Create an adapter based on the specified execution backend.
if execution_backend == "dlpack": if execution_backend == "tvm_ffi":
# Use TorchDLPackKernelAdapter for interoperability with PyTorch via DLPack. # Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack.
# But we need to ensure that the runtime is enabled and the runtime module is not None. # But we need to ensure that the runtime is enabled and the runtime module is not None.
assert tvm.runtime.enabled("llvm"), "DLPack backend requires LLVM runtime." assert (artifact.rt_mod is not None), "tvm_ffi backend requires a runtime module."
assert (artifact.rt_mod is not None), "DLPack backend requires a runtime module." adapter = TVMFFIKernelAdapter(
adapter = TorchDLPackKernelAdapter( params=artifact.params,
artifact.rt_mod, params=artifact.params, result_idx=out_idx) result_idx=out_idx,
target=target,
func_or_mod=tilelang_func,
host_mod=artifact.host_mod,
device_mod=artifact.device_mod,
rt_mod=artifact.rt_mod,
device_kernel_source=artifact.kernel_source,
verbose=verbose,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
elif execution_backend == "ctypes": elif execution_backend == "ctypes":
adapter = CtypesKernelAdapter( adapter = CtypesKernelAdapter(
params=artifact.params, params=artifact.params,
...@@ -251,7 +263,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -251,7 +263,7 @@ class JITKernel(Generic[_P, _T]):
func_or_mod=tilelang_func, func_or_mod=tilelang_func,
host_mod=artifact.host_mod, host_mod=artifact.host_mod,
device_mod=artifact.device_mod, device_mod=artifact.device_mod,
kernel_global_source=artifact.kernel_source, device_kernel_source=artifact.kernel_source,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags, compile_flags=compile_flags,
...@@ -264,7 +276,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -264,7 +276,7 @@ class JITKernel(Generic[_P, _T]):
func_or_mod=tilelang_func, func_or_mod=tilelang_func,
host_mod=artifact.host_mod, host_mod=artifact.host_mod,
device_mod=artifact.device_mod, device_mod=artifact.device_mod,
kernel_global_source=artifact.kernel_source, device_kernel_source=artifact.kernel_source,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags, compile_flags=compile_flags,
...@@ -278,7 +290,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -278,7 +290,7 @@ class JITKernel(Generic[_P, _T]):
func_or_mod=tilelang_func, func_or_mod=tilelang_func,
host_mod=artifact.host_mod, host_mod=artifact.host_mod,
device_mod=artifact.device_mod, device_mod=artifact.device_mod,
kernel_global_source=artifact.kernel_source, device_kernel_source=artifact.kernel_source,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags, compile_flags=compile_flags,
...@@ -308,7 +320,8 @@ class JITKernel(Generic[_P, _T]): ...@@ -308,7 +320,8 @@ class JITKernel(Generic[_P, _T]):
result_idx: list[int] | int, result_idx: list[int] | int,
target: str | Target, target: str | Target,
func_or_mod: PrimFunc | tvm.runtime.Module, func_or_mod: PrimFunc | tvm.runtime.Module,
kernel_global_source: str, host_kernel_source: str,
device_kernel_source: str,
kernel_lib_path: str, kernel_lib_path: str,
pass_configs: dict[str, Any] | None = None, pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None) -> BaseKernelAdapter: compile_flags: list[str] | None = None) -> BaseKernelAdapter:
...@@ -316,15 +329,26 @@ class JITKernel(Generic[_P, _T]): ...@@ -316,15 +329,26 @@ class JITKernel(Generic[_P, _T]):
execution_backend = self.execution_backend execution_backend = self.execution_backend
# Create an adapter based on the specified execution backend. # Create an adapter based on the specified execution backend.
if execution_backend == "dlpack": if execution_backend == "tvm_ffi":
raise ValueError("DLPack backend is not supported for TileLang JIT.") adapter = TVMFFIKernelAdapter.from_database(
params=params,
result_idx=result_idx,
target=target,
func_or_mod=func_or_mod,
host_kernel_source=host_kernel_source,
device_kernel_source=device_kernel_source,
kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
elif execution_backend == "ctypes": elif execution_backend == "ctypes":
adapter = CtypesKernelAdapter.from_database( adapter = CtypesKernelAdapter.from_database(
params=params, params=params,
result_idx=result_idx, result_idx=result_idx,
target=target, target=target,
func_or_mod=func_or_mod, func_or_mod=func_or_mod,
kernel_global_source=kernel_global_source, host_kernel_source=host_kernel_source,
device_kernel_source=device_kernel_source,
kernel_lib_path=kernel_lib_path, kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags, compile_flags=compile_flags,
...@@ -335,7 +359,8 @@ class JITKernel(Generic[_P, _T]): ...@@ -335,7 +359,8 @@ class JITKernel(Generic[_P, _T]):
result_idx=result_idx, result_idx=result_idx,
target=target, target=target,
func_or_mod=func_or_mod, func_or_mod=func_or_mod,
kernel_global_source=kernel_global_source, host_kernel_source=host_kernel_source,
device_kernel_source=device_kernel_source,
kernel_lib_path=kernel_lib_path, kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs, pass_configs=pass_configs,
) )
...@@ -346,7 +371,8 @@ class JITKernel(Generic[_P, _T]): ...@@ -346,7 +371,8 @@ class JITKernel(Generic[_P, _T]):
result_idx=result_idx, result_idx=result_idx,
target=target, target=target,
func_or_mod=func_or_mod, func_or_mod=func_or_mod,
kernel_global_source=kernel_global_source, host_kernel_source=host_kernel_source,
device_kernel_source=device_kernel_source,
kernel_lib_path=kernel_lib_path, kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags, compile_flags=compile_flags,
...@@ -394,7 +420,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -394,7 +420,7 @@ class JITKernel(Generic[_P, _T]):
return Profiler(self.params, self.out_idx, return Profiler(self.params, self.out_idx,
tensor_supply_type).with_default_adapter(self.adapter) tensor_supply_type).with_default_adapter(self.adapter)
def get_kernel_source(self) -> str: def get_kernel_source(self, kernel_only: bool = True) -> str:
""" """
Returns the source code of the compiled kernel function. Returns the source code of the compiled kernel function.
...@@ -403,14 +429,17 @@ class JITKernel(Generic[_P, _T]): ...@@ -403,14 +429,17 @@ class JITKernel(Generic[_P, _T]):
str str
The source code of the compiled kernel function. The source code of the compiled kernel function.
""" """
if self.execution_backend in {"ctypes", "cython", "nvrtc"}: if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi"}:
return self.adapter.get_kernel_source() return self.adapter.get_kernel_source(kernel_only=kernel_only)
return self.artifact.kernel_source return self.artifact.kernel_source
def get_host_source(self) -> str: def get_host_source(self) -> str:
""" """
Returns the source code of the host function. Returns the source code of the host function.
""" """
if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi"}:
return self.adapter.get_host_source()
assert self.artifact.host_mod is not None, "host_mod is not available"
return str(self.artifact.host_mod) return str(self.artifact.host_mod)
def run_once(self, func: Callable | None = None) -> None: def run_once(self, func: Callable | None = None) -> None:
......