"...composable_kernel.git" did not exist on "3ba485b61ae0894f24c74ae350598746befb5aab"
Commit 3471904f authored by Lei Wang, committed by GitHub

[JIT] Support Cython JIT and make Cython the default execution backend (#102)

* [Feature] Add ctypes JIT kernel support for dynamic shapes and multi-stream execution

- Enhance CtypesKernelAdapter to handle dynamic symbolic shapes (see the sketch after this list)
- Add support for multi-stream kernel execution in the ctypes backend
- Implement dynamic shape handling in test_tilelang_jit_gemm_ctypes.py
- Add a symbolic shape utility function in tilelang.language
- Update the profiler for more flexible benchmark selection
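A minimal sketch of the dynamic-shape path, assuming the `matmul` program builder defined in the test files of this diff and the `compile()` wrapper introduced later in this PR:

    import torch
    import tilelang
    import tilelang.language as T

    # M is symbolic; N, K and the tile configuration are fixed at compile time.
    program = matmul(
        T.symbolic("m"), 1024, 768,       # M, N, K
        128, 256, 32,                     # block_M, block_N, block_K
        False, False,                     # trans_A, trans_B
        "float16", "float16", "float16",  # in_dtype, out_dtype, accum_dtype
        2, 128)                           # num_stages, threads
    kernel = tilelang.compile(program, execution_backend="ctypes")

    # Any runtime M works; the adapter resolves the symbol from the tensor shapes.
    a = torch.randn(512, 768, dtype=torch.float16).cuda()
    b = torch.randn(768, 1024, dtype=torch.float16).cuda()
    c = torch.empty(512, 1024, dtype=torch.float16).cuda()
    kernel(a, b, c)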

* Remove redundant thread binding in GEMM kernel implementations

- Remove unnecessary `thread_binding` line in GEMM kernel functions
- Clean up code in `examples/gemm/README.md` and `testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py`
- Enhance code readability by removing redundant thread binding annotation

* Fix indentation in int4 GEMM kernel test file

- Correct indentation for function calls in `test_tilelang_kernel_int4_gemm_mma.py`
- Remove extra indentation in `mma_emitter.ldmatrix_a()` and `mma_emitter.ldmatrix_b()` calls
- Improve code formatting for better readability

* [Feature] Add Cython JIT kernel support for dynamic shapes and multi-stream execution

- Implement CythonKernelAdapter to handle dynamic symbolic shapes
- Add support for multi-stream kernel execution in the Cython backend (see the multi-stream sketch after this list)
- Create a comprehensive test suite for the Cython GEMM kernel in test_tilelang_jit_gemm_cython.py
- Update JITKernel to accept "cython" as a valid execution backend
- Add Cython-specific wrapper and library generation modules
- Update .gitignore to exclude the Cython cache directory
- Modify setup.py to include Cython source files in package data
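A multi-stream sketch with the Cython backend, mirroring run_cython_kernel_multi_stream from the new test suite (the `matmul` builder is the one defined there):

    import torch
    import tilelang

    program = matmul(512, 1024, 768, 128, 256, 32, False, False,
                     "float16", "float16", "float16", 2, 128)
    kernel = tilelang.compile(program, execution_backend="cython")
    a = torch.randn(512, 768, dtype=torch.float16).cuda()
    b = torch.randn(768, 1024, dtype=torch.float16).cuda()
    c = torch.empty(512, 1024, dtype=torch.float16).cuda()

    # Each launch targets whichever stream is current inside the context manager.
    for _ in range(4):
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            kernel(a, b, c)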

* lint fix

* [Refactor] Replace JITKernel with compile() function for kernel compilation

- Add a new `compile()` function in tilelang/jit/__init__.py as a wrapper around JITKernel
- Update multiple test files and examples to use `tilelang.compile()` instead of `tilelang.JITKernel()` (see the before/after sketch below)
- Modify kernel adapters to support optional kernel-only source retrieval
- Update `__init__.py` to import the new `compile()` function
- Improve kernel source retrieval for different execution backends
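The migration is mechanical, since compile() simply forwards its arguments to JITKernel; before/after, as in the README updates below (func is any TileLang PrimFunc):

    import tilelang

    # before
    jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
    # after (execution_backend now defaults to "cython")
    jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")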

* lint fix

* remove debug print

* Add C/C++ compiler utility module and update Cython JIT kernel support

- Introduce a new `tilelang/contrib/cc.py` module with cross-platform C/C++ compiler utilities (usage sketch below)
- Add functions to detect and retrieve the system's C/C++ compilers
- Implement cross-compilation and shared library creation support
- Update the Cython JIT kernel to validate C++ compiler availability
- Modify the Cython adapter to use the detected C++ compiler for library generation
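A usage sketch of the new utilities (the file names here are hypothetical):

    from tilelang.contrib import cc

    cxx = cc.get_cplus_compiler()  # e.g. /usr/bin/g++, or None if nothing suitable is on PATH
    if cxx is None:
        raise RuntimeError("No C++ compiler found")
    # Link object files into a shared library using the detected compiler.
    cc.create_shared("libkernel.so", ["kernel.o"], options=["-O2"], cc=cxx)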

* Refactor float8 dtype mapping in tensor utility module

- Move float8_dtype_map inside adapt_torch2tvm function
- Simplify global scope by localizing the dtype mapping
- Maintain existing functionality for converting torch float8 tensors to TVM ndarray

* Refactor float8 dtype mapping in tensor utility module

- Move float8_dtype_map inside adapt_torch2tvm function
- Simplify global scope by localizing the dtype mapping
- Maintain existing functionality for converting torch float8 tensors to TVM ndarray

* revert

* Enhance Cython JIT adapter with Cython compiler detection

- Add a `get_cython_compiler()` function to dynamically locate the Cython executable (detection sketch below)
- Update the Cython adapter to use the detected Cython compiler instead of a hardcoded command
- Raise an exception if no Cython compiler is found
- Update requirements.txt to specify a minimum PyTorch version (>=2.2.0)
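Detection sketch (the import path is assumed from the new adapter's layout):

    from tilelang.jit.adapter.cython.adapter import get_cython_compiler

    cython = get_cython_compiler()  # scans PATH for `cython` / `cython3` executables
    if cython is None:
        raise Exception("Cython is not installed, please install it first.")
    print(cython)  # e.g. /usr/local/bin/cython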

* Fix Cython kernel wrapper stream handling and type annotations

- Update the stream parameter type to int64_t for better compatibility
- Directly use torch.cuda.current_stream().cuda_stream instead of casting (call-side sketch below)
- Improve type safety and precision in the Cython kernel wrapper
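Call-side sketch, reusing `kernel` and the tensors from the sketches above (the `stream` keyword matches the adapter's forward signature):

    import torch

    s = torch.cuda.Stream()
    with torch.cuda.stream(s):
        kernel(a, b, c)  # the wrapper picks up torch.cuda.current_stream().cuda_stream
    kernel(a, b, c, stream=s.cuda_stream)  # or pass a raw stream handle (an int) explicitly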
parent 8d450c34
......@@ -85,3 +85,6 @@ tilelang/lib
# tox
.tox/
# cython
tilelang/jit/adapter/cython/.cycache
......@@ -154,7 +154,7 @@ func = matmul(1024, 1024, 1024, 128, 128, 32)
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
# 3. Test the kernel in Python with PyTorch data
import torch
......
......@@ -109,7 +109,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
func = matmul(1024, 1024, 1024, 128, 128, 32)
# 2. JIT-compile the kernel for NVIDIA GPU
jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
import torch
......
......@@ -63,7 +63,7 @@ func = matmul(1024, 1024, 1024, 128, 128, 32)
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
# 3. Test the kernel in Python with PyTorch data
import torch
......
......@@ -151,7 +151,7 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"
package_data = {
"tilelang": ["py.typed"],
"tilelang": ["py.typed", "*pyx"],
}
LLVM_VERSION = "10.0.1"
......@@ -227,7 +227,22 @@ class TileLangBuilPydCommand(build_py):
ext_output_dir = os.path.dirname(extdir)
print(f"Extension output directory (parent): {ext_output_dir}")
print(f"Build temp directory: {build_temp_dir}")
# copy cython files
CYTHON_SRC = [
"tilelang/jit/adapter/cython/cython_wrapper.pyx",
]
for item in CYTHON_SRC:
source_dir = os.path.join(ROOT_DIR, item)
target_dir = os.path.join(self.build_lib, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
shutil.copy2(source_dir, target_dir)
# copy the tl_templates
TILELANG_SRC = [
"src/tl_templates",
]
......
......@@ -14,7 +14,7 @@ def debug_print_buffer(M=16, N=16):
shared_buf = T.alloc_shared([M, N], dtype)
T.print(shared_buf)
jit_kernel = tilelang.JITKernel(program, target="cuda")
jit_kernel = tilelang.compile(program, target="cuda")
profiler = jit_kernel.get_profiler()
profiler.run_once()
......@@ -34,7 +34,7 @@ def debug_print_buffer_conditional(M=16, N=16):
if bx == 0 and by == 0 and bz == 0:
T.print(shared_buf)
jit_kernel = tilelang.JITKernel(program, target="cuda")
jit_kernel = tilelang.compile(program, target="cuda")
profiler = jit_kernel.get_profiler()
profiler.run_once()
......@@ -53,7 +53,7 @@ def debug_print_value_conditional(M=16, N=16):
if tid == 0:
T.print(bx + by + bz)
jit_kernel = tilelang.JITKernel(program, target="cuda")
jit_kernel = tilelang.compile(program, target="cuda")
profiler = jit_kernel.get_profiler()
profiler.run_once()
......@@ -72,7 +72,7 @@ def debug_print_register_files(M=16, N=16):
for i, j in T.Parallel(M, N):
T.print(register_buf[i, j])
jit_kernel = tilelang.JITKernel(program, target="cuda")
jit_kernel = tilelang.compile(program, target="cuda")
profiler = jit_kernel.get_profiler()
profiler.run_once()
......@@ -91,7 +91,7 @@ def debug_print_msg(M=16, N=16):
if tid == 0:
T.print(bx + by + bz, msg="hello world")
jit_kernel = tilelang.JITKernel(program, target="cuda")
jit_kernel = tilelang.compile(program, target="cuda")
profiler = jit_kernel.get_profiler()
profiler.run_once()
......
......@@ -42,7 +42,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
def run_gemm_pipeline_test(N, block_M=128, block_N=128, block_K=32):
func = matmul(N, N, N, block_M, block_N, block_K)
jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
torch.manual_seed(0)
a = torch.randn(N, N, device="cuda", dtype=torch.float16)
......
......@@ -93,7 +93,7 @@ def run_gemm(
code = f"// {stramp}\n" + code
return code
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")
matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")
kernel_source = matmul_kernel.get_kernel_source()
......@@ -196,7 +196,7 @@ def run_gemm_jit_kernel(
num_threads,
)
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")
matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
......
......@@ -206,7 +206,7 @@ def run_gemm_jit_kernel(
num_threads,
)
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")
matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
......
......@@ -92,7 +92,7 @@ def run_gemm(
code = f"// {stramp}\n" + code
return code
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="ctypes")
matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="ctypes")
kernel_source = matmul_kernel.get_kernel_source()
......@@ -195,7 +195,7 @@ def run_gemm_jit_kernel(
num_threads,
)
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="ctypes")
matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="ctypes")
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
......@@ -263,7 +263,7 @@ def run_ctypes_kernel_do_bench(M,
num_threads,
)
matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
matmul_kernel = tilelang.compile(program, execution_backend="ctypes")
profiler = matmul_kernel.get_profiler()
......@@ -312,7 +312,7 @@ def run_ctypes_kernel_multi_stream(M,
num_threads,
)
matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
matmul_kernel = tilelang.compile(program, execution_backend="ctypes")
tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
......@@ -364,7 +364,7 @@ def run_ctypes_dynamic_shape(M,
num_threads,
)
matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
matmul_kernel = tilelang.compile(program, execution_backend="ctypes")
if isinstance(M, T.Var):
M = 1024
if isinstance(N, T.Var):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.language as T
import tilelang.testing
import tilelang
import torch
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
stramp = "&*(XS)"
@tvm.register_func("tilelang_callback_cuda_postproc", override=True)
def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code
return code
matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="cython")
kernel_source = matmul_kernel.get_kernel_source()
assert stramp in kernel_source, f"Expected {stramp} in the kernel source"
def test_gemm_f16f16f16_nn():
run_gemm(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def matmul_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_jit_kernel(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="cython")
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
if trans_A:
A = A.T
if trans_B:
B = B.T
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
ref_C = ref_program(A, B)
C = matmul_kernel(A, B)
tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_gemm_jit_kernel():
run_gemm_jit_kernel(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def run_cython_kernel_do_bench(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
cython_matmul_kernel = tilelang.compile(program, execution_backend="cython")
ctypes_matmul_kernel = tilelang.compile(program, execution_backend="ctypes")
cython_profiler = cython_matmul_kernel.get_profiler()
ctypes_profiler = ctypes_matmul_kernel.get_profiler()
cython_latency = cython_profiler.do_bench(func=cython_matmul_kernel, profiler="torch")
print(f"cython Latency: {cython_latency} ms")
assert cython_latency is not None
tvm_latency = cython_profiler.do_bench()
print(f"TVM Latency: {tvm_latency} ms")
assert tvm_latency is not None
ctypes_latency = ctypes_profiler.do_bench(func=ctypes_matmul_kernel, profiler="torch")
print(f"ctypes Latency: {ctypes_latency} ms")
assert ctypes_latency is not None
def test_cython_kernel_do_bench():
run_cython_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
def run_cython_kernel_multi_stream(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, execution_backend="cython")
tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=torch.__getattribute__(out_dtype)).cuda()
num_streams = 4
for _ in range(num_streams):
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
matmul_kernel(tensor_a, tensor_b, tensor_c)
def test_cython_kernel_multi_stream():
run_cython_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16",
128, 256, 32, 2)
def run_cython_dynamic_shape(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, execution_backend="cython")
if isinstance(M, T.Var):
M = 1024
if isinstance(N, T.Var):
N = 1024
if isinstance(K, T.Var):
K = 768
tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=torch.__getattribute__(out_dtype)).cuda()
matmul_kernel(tensor_a, tensor_b, tensor_c)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float))
tilelang.testing.torch_assert_close(
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_cython_dynamic_shape():
run_cython_dynamic_shape(
T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
run_cython_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), T.symbolic("k"), False, False, "float16", "float16",
"float16", 128, 256, 32, 2)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -105,7 +105,7 @@ def _load_tile_lang_lib():
if SKIP_LOADING_TILELANG_SO == "0":
_LIB, _LIB_PATH = _load_tile_lang_lib()
from .jit import jit, JITKernel # noqa: F401
from .jit import jit, JITKernel, compile # noqa: F401
from .profiler import Profiler # noqa: F401
from .utils import (
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Util to invoke C/C++ compilers in the system."""
import os
import shutil
import subprocess
# pylint: disable=invalid-name
import sys
from typing import Dict
from tvm._ffi.base import py_str
from tvm.contrib import tar as _tar
from tvm.contrib import utils as _utils
def _is_linux_like():
return (sys.platform == "darwin" or sys.platform.startswith("linux") or
sys.platform.startswith("freebsd"))
def _is_windows_like():
return sys.platform == "win32"
def get_cc():
"""Return the path to the default C/C++ compiler.
Returns
-------
out: Optional[str]
The path to the default C/C++ compiler, or None if none was found.
"""
if not _is_linux_like():
return None
env_cxx = os.environ.get("CXX") or os.environ.get("CC")
if env_cxx:
return env_cxx
cc_names = ["g++", "gcc", "clang++", "clang", "c++", "cc"]
dirs_in_path = os.get_exec_path()
for cc in cc_names:
for d in dirs_in_path:
cc_path = os.path.join(d, cc)
if os.path.isfile(cc_path) and os.access(cc_path, os.X_OK):
return cc_path
return None
def get_cplus_compiler():
"""Return the path to the default C/C++ compiler.
Returns
-------
out: Optional[str]
The path to the default C/C++ compiler, or None if none was found.
"""
if not _is_linux_like():
return None
env_cxx = os.environ.get("CXX") or os.environ.get("CC")
if env_cxx:
return env_cxx
cc_names = ["g++", "clang++", "c++"]
dirs_in_path = os.get_exec_path()
for cc in cc_names:
for d in dirs_in_path:
cc_path = os.path.join(d, cc)
if os.path.isfile(cc_path) and os.access(cc_path, os.X_OK):
return cc_path
return None
def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=None):
"""Create shared library.
Parameters
----------
output : str
The target shared library.
objects : List[str]
List of object files.
options : List[str]
The list of additional option strings.
cc : Optional[str]
The compiler command.
cwd : Optional[str]
The current working directory.
ccache_env : Optional[Dict[str, str]]
The environment variables for ccache. `None` (the default) disables ccache.
"""
cc = cc or get_cc()
if _is_linux_like():
_linux_compile(output, objects, options, cc, cwd, ccache_env, compile_shared=True)
elif _is_windows_like():
_windows_compile(output, objects, options, cwd, ccache_env)
else:
raise ValueError("Unsupported platform")
def _linux_ar(output, inputs, ar):
ar = ar or "ar"
libname = os.path.basename(output)
if not libname.startswith("lib"):
libname = "lib" + libname
temp = _utils.tempdir()
temp_output = temp.relpath(libname)
cmd = [ar, "-crs", temp_output]
# handles the case where some input files are tar of objects
# unpack them and return the list of files inside
objects = _tar.normalize_file_list_by_unpacking_tars(temp, inputs)
cmd += objects
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "AR error:\n"
msg += py_str(out)
msg += "\nCommand line: " + " ".join(cmd)
raise RuntimeError(msg)
shutil.move(temp_output, output)
def create_staticlib(output, inputs, ar=None):
"""Create static library.
Parameters
----------
output : str
The target static library.
inputs : List[str]
List of input files. Each input file can be a tarball
of objects or an object file.
ar : Optional[str]
Path to the ar command to be used
"""
if _is_linux_like():
return _linux_ar(output, inputs, ar)
else:
raise ValueError("Unsupported platform")
def create_executable(output, objects, options=None, cc=None, cwd=None, ccache_env=None):
"""Create executable binary.
Parameters
----------
output : str
The target executable.
objects : List[str]
List of object files.
options : List[str]
The list of additional option strings.
cc : Optional[str]
The compiler command.
cwd : Optional[str]
The current working directory.
ccache_env : Optional[Dict[str, str]]
The environment variables for ccache. `None` (the default) disables ccache.
"""
cc = cc or get_cc()
if _is_linux_like():
_linux_compile(output, objects, options, cc, cwd, ccache_env)
elif _is_windows_like():
_windows_compile(output, objects, options, cwd, ccache_env)
else:
raise ValueError("Unsupported platform")
def get_global_symbol_section_map(path, *, nm=None) -> Dict[str, str]:
"""Get global symbols from a library via nm -g
Parameters
----------
path : str
The library path
nm: str
The path to nm command
Returns
-------
symbol_section_map: Dict[str, str]
A map from defined global symbol to their sections
"""
if nm is None:
if not _is_linux_like():
raise ValueError("Unsupported platform")
nm = "nm"
symbol_section_map = {}
if not os.path.isfile(path):
raise FileNotFoundError(f"{path} does not exist")
cmd = [nm, "-gU", path]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Runtime error:\n"
msg += py_str(out)
raise RuntimeError(msg)
for line in py_str(out).split("\n"):
data = line.strip().split()
if len(data) != 3:
continue
symbol = data[-1]
section = data[-2]
symbol_section_map[symbol] = section
return symbol_section_map
def get_target_by_dump_machine(compiler):
"""Functor of get_target_triple that can get the target triple using compiler.
Parameters
----------
compiler : Optional[str]
The compiler.
Returns
-------
out: Callable
A function that gets the target triple via the compiler's -dumpmachine option.
"""
def get_target_triple():
"""Get target triple according to dumpmachine option of compiler."""
if compiler:
cmd = [compiler, "-dumpmachine"]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "dumpmachine error:\n"
msg += py_str(out)
return None
return py_str(out)
return None
return get_target_triple
# assign "so" as the default output format
create_shared.output_format = "so" if sys.platform != "win32" else "dll"
create_shared.get_target_triple = get_target_by_dump_machine(os.environ.get("CXX", get_cc()))
def cross_compiler(compile_func,
options=None,
output_format=None,
get_target_triple=None,
add_files=None):
"""Create a cross compiler function by specializing compile_func with options.
This function can be used to construct compile functions that
can be passed to AutoTVM measure or export_library.
Parameters
----------
compile_func : Union[str, Callable[[str, str, Optional[str]], None]]
Function that performs the actual compilation
options : Optional[List[str]]
List of additional option strings.
output_format : Optional[str]
Library output format.
get_target_triple: Optional[Callable]
Function that can get the target triple via the compiler's -dumpmachine option.
add_files: Optional[List[str]]
List of paths to additional object, source, library files
to pass as part of the compilation.
Returns
-------
fcompile : Callable[[str, str, Optional[str]], None]
A compilation function that can be passed to export_library.
Examples
--------
.. code-block:: python
from tvm.contrib import cc, ndk
# export using arm gcc
mod = build_runtime_module()
mod.export_library(path_dso,
fcompile=cc.cross_compiler("arm-linux-gnueabihf-gcc"))
# specialize ndk compilation options.
specialized_ndk = cc.cross_compiler(
ndk.create_shared,
["--sysroot=/path/to/sysroot", "-shared", "-fPIC", "-lm"])
mod.export_library(path_dso, fcompile=specialized_ndk)
"""
base_options = [] if options is None else options
kwargs = {}
add_files = [] if add_files is None else add_files
# handle case where compile_func is the name of the cc
if isinstance(compile_func, str):
kwargs = {"cc": compile_func}
compile_func = create_shared
def _fcompile(outputs, objects, options=None):
all_options = list(base_options)  # copy so repeated calls don't mutate base_options
if options is not None:
all_options += options
compile_func(outputs, objects + add_files, options=all_options, **kwargs)
if not output_format and hasattr(compile_func, "output_format"):
output_format = compile_func.output_format
output_format = output_format if output_format else "so"
if not get_target_triple and hasattr(compile_func, "get_target_triple"):
get_target_triple = compile_func.get_target_triple
_fcompile.output_format = output_format
_fcompile.get_target_triple = get_target_triple
return _fcompile
def _linux_compile(output,
objects,
options,
compile_cmd,
cwd=None,
ccache_env=None,
compile_shared=False):
cmd = [compile_cmd]
if compile_cmd != "nvcc":
if compile_shared or output.endswith(".so") or output.endswith(".dylib"):
cmd += ["-shared", "-fPIC"]
if sys.platform == "darwin":
cmd += ["-undefined", "dynamic_lookup"]
elif output.endswith(".obj"):
cmd += ["-c"]
else:
if compile_shared or output.endswith(".so") or output.endswith(".dylib"):
cmd += ["-shared"]
cmd += ["-o", output]
if isinstance(objects, str):
cmd += [objects]
else:
cmd += objects
if options:
cmd += options
env = None
if ccache_env is not None:
if shutil.which("ccache"):
cmd.insert(0, "ccache")
env = os.environ.copy()
env.update(ccache_env)
else:
raise ValueError("ccache not found")
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Compilation error:\n"
msg += py_str(out)
msg += "\nCommand line: " + " ".join(cmd)
raise RuntimeError(msg)
def _windows_compile(output, objects, options, cwd=None, ccache_env=None):
cmd = ["clang"]
cmd += ["-O2"]
if output.endswith(".so") or output.endswith(".dll"):
cmd += ["-shared"]
elif output.endswith(".obj"):
cmd += ["-c"]
if isinstance(objects, str):
objects = [objects]
cmd += ["-o", output]
cmd += objects
if options:
cmd += options
env = None
if ccache_env is not None:
if shutil.which("ccache"):
cmd.insert(0, "ccache")
env = os.environ.copy()
env.update(ccache_env)
else:
raise ValueError("ccache not found")
try:
proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env)
(out, _) = proc.communicate()
except FileNotFoundError:
raise RuntimeError("Can not find the LLVM clang for Windows clang.exe)."
"Make sure it's installed"
" and the installation directory is in the %PATH% environment "
"variable. Prebuilt binaries can be found at: https://llvm.org/") \
from None
if proc.returncode != 0:
msg = "Compilation error:\n"
msg += " ".join(cmd) + "\n"
msg += py_str(out)
raise RuntimeError(msg)
......@@ -105,3 +105,23 @@ def jit(
return _compile_and_create_adapter(tilelang_func)
return real_decorator
def compile(
func: PrimFunc = None,
out_idx: Union[List[int], int] = None,
execution_backend: Literal["dlpack", "torch_cpp", "ctypes", "cython"] = "cython",
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
verbose: bool = False,
) -> JITKernel:
"""
Compile the given TileLang PrimFunc with TVM and build a JITKernel.
"""
return JITKernel(
func,
out_idx=out_idx,
execution_backend=execution_backend,
target=target,
target_host=target_host,
verbose=verbose)
......@@ -5,3 +5,4 @@ from .base import BaseKernelAdapter # noqa: F401
from .dlpack import TorchDLPackKernelAdapter # noqa: F401
from .torchcpp import TorchCPPKernelAdapter # noqa: F401
from .ctypes import CtypesKernelAdapter # noqa: F401
from .cython import CythonKernelAdapter # noqa: F401
......@@ -67,7 +67,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
self.lib_generator = LibraryGenerator(self.target)
self.wrapper.assign_optimized_module(self.ir_module)
self.wrapped_source = self.wrapper.wrap(self.get_kernel_source())
self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
self.lib_generator.update_lib_code(self.wrapped_source)
self.lib_generator.compile_lib()
......@@ -185,3 +185,11 @@ class CtypesKernelAdapter(BaseKernelAdapter):
def is_dynamic(self):
"""Indicates whether the kernel handles dynamic shapes."""
return (self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0)
def get_kernel_source(self, kernel_only: bool = False):
"""Returns the source code of the compiled kernel."""
if kernel_only:
return self.mod.imported_modules[0].get_source()
else:
assert self.wrapped_source is not None, "Wrapped source is not available"
return self.wrapped_source
......@@ -175,10 +175,11 @@ class TLCUDASourceWrapper(object):
# Determine the shared memory size, defaulting to 0 if not specified
smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf
# Format the CUDA kernel launch string
call_str = ""
if len(dynamic_symbolic_set) != 0:
call_str = "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0])
call_str += "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0])
else:
call_str = ""
call_str += ""
call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str,
smem_str, call_args)
# Create the host function wrapper for the CUDA kernel
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .adapter import CythonKernelAdapter # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
from ..base import BaseKernelAdapter
import ctypes
from typing import List, Optional, Union, Callable, Dict, Tuple
from tilelang import tvm as tvm
from tvm.target import Target
from tvm.relay import TensorType
from tvm import tir
from .wrapper import TLWrapper
from .libgen import LibraryGenerator
from tilelang.utils.target import determine_target
from tilelang.utils.language import retrieve_func_from_module
from tilelang.contrib.cc import get_cplus_compiler
import sys
import sysconfig
import hashlib
import os
from pathlib import Path
import logging
logger = logging.getLogger("tilelang")
def get_cython_compiler() -> Optional[str]:
"""Return the path to the Cython compiler.
Returns
-------
out: Optional[str]
The path to the Cython compiler, or None if none was found.
"""
cython_names = ["cython", "cython3"]
dirs_in_path = os.get_exec_path()
for cython_name in cython_names:
for d in dirs_in_path:
cython_path = os.path.join(d, cython_name)
if os.path.isfile(cython_path) and os.access(cython_path, os.X_OK):
return cython_path
return None
# Add cache management functions at module level
def get_cache_dir() -> Path:
"""Get the cache directory for the current Python version."""
py_version = f"py{sys.version_info.major}{sys.version_info.minor}"
# current directory
current_dir = os.path.dirname(os.path.abspath(__file__))
cache_dir = Path(current_dir) / ".cycache" / py_version
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir
def get_cached_lib(source_code: str) -> Tuple[Optional[ctypes.CDLL], Path]:
"""Try to load cached library or return None if not found."""
code_hash = hashlib.sha256(source_code.encode()).hexdigest()
cache_path = get_cache_dir() / f"{code_hash}.so"
if cache_path.exists():
try:
return ctypes.CDLL(str(cache_path)), cache_path
except Exception as e:
logger.error(f"Failed to load cached library: {e}")
return None, cache_path
return None, cache_path
# read the cython_wrapper.pyx file
current_dir = os.path.dirname(os.path.abspath(__file__))
cython_wrapper_path = os.path.join(current_dir, "cython_wrapper.pyx")
with open(cython_wrapper_path, "r") as f:
cython_wrapper_code = f.read()
cache_dir = get_cache_dir()
source_path = cache_dir / "cython_wrapper.cpp"
library_path = cache_dir / "cython_wrapper.so"
md5_path = cache_dir / "md5.txt"
code_hash = hashlib.sha256(cython_wrapper_code.encode()).hexdigest()
# Check if cached version exists and is valid
need_compile = True
if md5_path.exists() and library_path.exists():
with open(md5_path, "r") as f:
cached_hash = f.read().strip()
if cached_hash == code_hash:
logger.debug("Cython jit adapter is up to date, no need to compile...")
need_compile = False
else:
logger.info("Cython jit adapter is out of date, need to compile...")
else:
logger.info("No cached version found for cython jit adapter, need to compile...")
if need_compile:
logger.info("Compiling cython jit adapter...")
# compile the cython_wrapper.pyx file into .cpp
cython = get_cython_compiler()
if cython is None:
raise Exception("Cython is not installed, please install it first.")
os.system(f"{cython} {cython_wrapper_path} --cplus -o {source_path}")
# compile the .cpp file into .so
python_include_path = sysconfig.get_path("include")
cc = get_cplus_compiler()
command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing -I{python_include_path} {source_path} -o {library_path}"
# os.system does not raise on failure, so check the exit status explicitly
ret = os.system(command)
if ret != 0:
raise Exception(f"Failed to compile cython jit adapter: compiler exited with status {ret}")
# record the hash only after a successful build so a failed build is retried
with open(md5_path, "w") as f:
f.write(code_hash)
# add the .so file to the sys.path
cache_dir_str = str(cache_dir)
if cache_dir_str not in sys.path:
sys.path.append(cache_dir_str)
from cython_wrapper import CythonKernelWrapper
class CythonKernelAdapter(BaseKernelAdapter):
"""Adapter class that converts TVM/TIR functions to callable CUDA kernels using ctypes.
This adapter handles:
1. Converting TIR functions to compiled CUDA libraries
2. Managing dynamic shapes in tensor operations
3. Wrapping C++ kernels for Python/PyTorch usage
"""
# Class attributes to store compiled kernel information
target: str = "cuda"
ir_module: Optional[tvm.IRModule] = None
lib: Optional[ctypes.CDLL] = None # Compiled library handle
wrapped_source: Optional[str] = None # Generated C++ wrapper code
# Maps symbolic variables to their corresponding buffer and shape indices
dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None
def __init__(self,
rt_mod,
params: List[TensorType],
result_idx: List[int],
target,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
verbose: bool = False):
"""Initialize the adapter with the given TIR function or module.
Args:
rt_mod: Runtime module
params: List of tensor types for inputs/outputs
result_idx: Indices of output tensors
target: Target platform (e.g., 'cuda')
func_or_mod: TIR function or module to be compiled
verbose: Enable verbose logging
"""
self.mod = rt_mod
self.params = params
self.result_idx = self._legalize_result_idx(result_idx)
if isinstance(func_or_mod, tir.PrimFunc):
self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
else:
self.ir_module = func_or_mod
self.dynamic_symbolic_map = self._process_dynamic_symbolic()
self.target = Target.canon_target(determine_target(target))
self.verbose = verbose
self.wrapper = TLWrapper(self.target)
self.lib_generator = LibraryGenerator(self.target)
self.wrapper.assign_optimized_module(self.ir_module)
self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
self.lib_generator.update_lib_code(self.wrapped_source)
self.lib_generator.compile_lib()
self.lib = self.lib_generator.load_lib()
self.lib.init()
self.cython_wrapper = CythonKernelWrapper(self.dynamic_symbolic_map, self.result_idx,
self.params, self.lib)
self._post_init()
def _process_dynamic_symbolic(self):
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
for runtime shape resolution.
"""
func = self.prim_func
params = func.params
buffer_map = func.buffer_map
dynamic_symbolic_map = {}
for i, param in enumerate(params):
buffer = buffer_map[param]
for j, shape in enumerate(buffer.shape):
if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map):
dynamic_symbolic_map[shape] = (i, j)
return dynamic_symbolic_map
def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
"""Low-level function to call the compiled CUDA kernel.
Converts PyTorch tensor pointers to C void pointers for ctypes interface.
"""
ctypes_args = [
ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args
]
ctypes_args.append(ctypes.c_void_p(stream))
self.lib.call(*ctypes_args)
def _convert_torch_func(self) -> Callable:
"""Returns a PyTorch-compatible function wrapper for the kernel."""
def lambda_forward(*args, stream: int = -1):
return self.cython_wrapper.forward([*args], stream=stream)
return lambda_forward
@property
def prim_func(self) -> tir.PrimFunc:
"""Returns the primary TIR function from the IR module."""
return retrieve_func_from_module(self.ir_module)
@property
def srcpath(self):
"""Returns the source path of the compiled library."""
return self.lib_generator.srcpath
@property
def libpath(self):
"""Returns the path to the compiled library."""
return self.lib_generator.libpath
@property
def lib_code(self):
"""Returns the code of the compiled library."""
return self.lib_generator.lib_code
@property
def is_dynamic(self):
"""Indicates whether the kernel handles dynamic shapes."""
return (self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0)
def get_kernel_source(self, kernel_only: bool = False):
"""Returns the source code of the compiled kernel."""
if kernel_only:
return self.mod.imported_modules[0].get_source()
else:
assert self.wrapped_source is not None, "Wrapped source is not available"
return self.wrapped_source