Commit 2ac51a03 authored by Lei Wang, committed by GitHub

[Wrap] Use a ctypes-based kernel wrapper instead of dlpack for runtime efficiency (#95)

* Bump version to v0.1.0

* [Enhancement] Add custom develop command for editable installs and update .gitignore

* [Documentation] Update README to include system dependencies installation instructions

* [Build] Update setup.py to support library file copying for both release and develop modes

* [Build] Refactor library file copying logic in setup.py

* [Documentation] Remove unnecessary install section header in Installation.md

* [Build] Add tox configuration and local distribution script for multi-Python version support

* [Build] Improve git submodule update function with better error handling

* [Build] Update LLVM configuration path in ROCm installation script

* [Build] Add .tox/ to .gitignore for tox testing environment

* [Build] Add support for TVM prebuild path configuration in CMakeLists.txt

* [Cleanup] Remove unused TVM runtime error codes header

* [Cleanup] Fix TVM grid constant type reference in CUDA module

* [Cleanup] Remove unused customized_code function from IR module

* [Feature] Add TileLang thread synchronization and storage access analysis passes

* [Build] Reorder DLL search path directories for more flexible library loading

* [Refactor] Improve thread synchronization and library path handling

- Rename ThreadSync and TileLangThreadSync functions in C++ code
- Update Python docstring for ThreadSync with more detailed description
- Reorder library path detection in tilelang environment setup
- Minor comment and code cleanup in CUDA and warp specialization modules

* [Refactor] Improve thread synchronization code style and formatting

- Standardize pointer type spacing in storage_access.h and storage_access.cc
- Update whitespace and indentation in thread_storage_sync.cc
- Reorder include statements in thread_partial_sync.cc
- Minor code formatting improvements across thread synchronization files

* [Refactor] Fix global function registration for ThreadSync

- Correct global function registration to use ThreadSync instead of TileLangThreadSync
- Update TVM global registration to match recent refactoring efforts

* [Refactor] Simplify ThreadSync global function registration

- Remove unnecessary whitespace in global function registration
- Compact the TVM global registration line for ThreadSync

* [Feature] Add WebGPU code generation support in TileLang

- Implement WebGPU code generator (codegen_webgpu.cc and codegen_webgpu.h)
- Add WebGPU target support in lower.py and target.py
- Update CMakeLists.txt to include WebGPU codegen source files
- Introduce WebGPU-specific code generation for WGSL shader language

* [Refactor] Improve WebGPU code generation formatting and readability

- Enhance code formatting in codegen_webgpu.cc and codegen_webgpu.h
- Standardize pointer type spacing and indentation
- Improve line breaks and reduce line length for better readability
- Minor code style improvements in WebGPU code generation

* [Test] Add WebGPU matrix multiplication code generation test

- Implement test_webgpu_codegen.py for WebGPU matrix multiplication
- Add assert_gemm_codegen function to validate WebGPU code generation
- Include basic matrix multiplication kernel test case

* Update README with WebGPU codegen support announcement

* Support multi-version PyPI package builds via tox

* Add support for CPU device backend with C code generation

- Introduce `is_cpu_device_backend` function to detect CPU backend with C code generation
- Modify `lower` function to handle special case of CPU device backend
- Update host and device call filtering for CPU backend
- Add conditional source code generation for C host target
- Extend JITKernel to support optional target_host parameter

* lint fix

* Enhance JIT kernel adapters with CTypes and Torch C++ backends

- Add CtypesKernelAdapter with dynamic library generation and kernel wrapping
- Implement TorchCPPKernelAdapter for CUDA kernel compilation
- Refactor BaseKernelAdapter to support more flexible initialization
- Improve error handling and argument processing in kernel adapters
- Update adapter initialization to support various execution backends

* Refactor and clean up code style in JIT CTypes adapter modules

- Apply consistent code formatting and whitespace in CTypes adapter files
- Remove unused imports and improve import organization
- Enhance readability of code in adapter, libgen, and wrapper modules
- Add missing whitespace and improve line breaks
- Minor linting and code style improvements across CTypes adapter files

* Add test for TileLang JIT GEMM with CTypes backend

- Implement comprehensive test for matrix multiplication using CTypes execution backend
- Create test functions for GEMM with float16 data type
- Add kernel source verification with custom callback
- Implement reference implementation using PyTorch for result validation
- Support various matrix multiplication configurations (transposition, block sizes)
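
For orientation, a minimal usage sketch (not part of the commit message) of the ctypes execution backend this test exercises; `matmul` is the kernel builder defined in the new test file below, and the shapes mirror test_gemm_jit_kernel:

import torch
import tilelang

# Build the TileLang program and JIT-compile it through the new ctypes backend.
program = matmul(512, 1024, 768, 128, 256, 32, False, False,
                 "float16", "float16", "float16", 2, 128)
kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="ctypes")
A = torch.randn(512, 768, dtype=torch.float16).cuda()
B = torch.randn(768, 1024, dtype=torch.float16).cuda()
C = kernel(A, B)  # the output tensor is allocated by the adapter (out_idx=-1)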

* test fix

* Update TileLang JIT callback registration with override parameter

- Modify tilelang_callback_cuda_postproc to use @tvm.register_func(override=True)
- Ensure proper function registration with ability to replace existing implementations
parent fa9a19b0
@@ -88,7 +88,7 @@ def run_gemm(
stramp = "&*(XS)"
@tvm.register_func
@tvm.register_func(override=True)
def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code
return code
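
For context (an editorial illustration, not part of the diff): TVM refuses to register a global function name twice unless override=True is passed, which is why the callback re-registration in run_gemm needs the new form; a minimal, hedged sketch with the name given explicitly:

from tilelang import tvm as tvm

@tvm.register_func("tilelang_callback_cuda_postproc", override=True)
def tilelang_callback_cuda_postproc(code, _):
    # Without override=True, a second registration of the same name raises an error.
    return "// patched\n" + code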
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang
import torch
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
stramp = "&*(XS)"
@tvm.register_func(override=True)
def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code
return code
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="ctypes")
kernel_source = matmul_kernel.get_kernel_source()
assert stramp in kernel_source, f"Expected {stramp} in the kernel source"
def test_gemm_f16f16f16_nn():
run_gemm(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def matmu_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_jit_kernel(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmu_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="ctypes")
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
if trans_A:
A = A.T
if trans_B:
B = B.T
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
ref_C = ref_program(A, B)
C = matmul_kernel(A, B)
tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_gemm_jit_kernel():
run_gemm_jit_kernel(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
if __name__ == "__main__":
tilelang.testing.main()
@@ -5,7 +5,7 @@
import tilelang as tl
import os
import os.path as osp
from typing import Union, Optional
from typing import Union, Optional, Callable
from tilelang import tvm as tvm
from tvm import tir, relay
from tvm.ir import CallingConv
@@ -14,21 +14,36 @@ from tilelang.contrib import hipcc, nvcc
from tilelang.utils.target import determine_target
def is_device_call(func: tir.PrimFunc):
def is_cpu_device_backend(target: Target):
return target.kind.name == "c"
def has_device_kernel_launch(attrs) -> bool:
"""Check if the attributes indicate a device kernel launch."""
return bool(attrs and "calling_conv" in attrs and
attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)
def is_device_call_c_device(func: tir.PrimFunc):
attrs = func.attrs
# consider c source as a device call
if "target" in attrs:
target = attrs["target"]
if target.kind.name == "c":
return True
# Check if it's a C target
if "target" in attrs and attrs["target"].kind.name == "c":
return True
return has_device_kernel_launch(attrs)
return bool(func.attrs and "calling_conv" in func.attrs and
func.attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)
def is_device_call(func: tir.PrimFunc):
return has_device_kernel_launch(func.attrs)
def get_device_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:
return is_device_call_c_device if is_device_c else is_device_call
def is_host_call(func: tir.PrimFunc):
return not is_device_call(func)
def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:
return lambda func: not get_device_call(is_device_c)(func)
@tvm.register_func("tilelang_callback_cuda_compile", override=True)
@@ -134,6 +149,9 @@ def lower(
target_host = tvm.target.Target.canon_target(target_host)
target = tvm.target.Target(target, target_host)
_is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target))
_is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target))
mod = tir.transform.BindTarget(target)(mod)
mod = tl.transform.FrontendLegalize()(mod)
@@ -196,7 +214,7 @@ def lower(
mod = tl.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
host_mod = tir.transform.Filter(is_host_call)(mod)
host_mod = tir.transform.Filter(_is_host_call)(mod)
host_mod = tir.transform.BindTarget(target_host)(host_mod)
host_mod = tir.transform.FP8StorageLegalize()(host_mod)
host_mod = tir.transform.BF16StorageLegalize()(host_mod)
@@ -209,11 +227,14 @@ def lower(
if target_host.kind.name == "llvm":
host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host)
elif target_host.kind.name == "c":
host_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(host_mod, target_host)
if is_cpu_device_backend(target):
host_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(host_mod, target_host)
else:
host_mod = tvm._ffi.get_global_func("target.build.c")(host_mod, target_host)
else:
raise ValueError("Target host is not supported")
raise ValueError(f"Target host {target_host.kind.name} is not supported")
device_mod = tir.transform.Filter(is_device_call)(mod)
device_mod = tir.transform.Filter(_is_device_call)(mod)
device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
device_mod = tir.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod)
@@ -231,10 +252,18 @@ def lower(
elif target.kind.name == "webgpu":
device_mod = tvm._ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
else:
raise ValueError("Target is not supported")
raise ValueError(f"Target {target.kind.name} is not supported")
host_mod.import_module(device_mod)
if target_host.kind.name == "c":
# cpu host should be recompiled
# TODO(lei): this is a hack to make the C host backend work
temp_dir = tvm.contrib.utils.tempdir()
tmp_lib_path = temp_dir.relpath("tmp.so")
host_mod.export_library(tmp_lib_path)
host_mod = tvm.runtime.load_module(tmp_lib_path)
if runtime_only is True:
return host_mod
else:
......
@@ -3,4 +3,5 @@
from .base import BaseKernelAdapter # noqa: F401
from .dlpack import TorchDLPackKernelAdapter # noqa: F401
from .torch_cpp import TorchCPPKernelAdapter # noqa: F401
from .torchcpp import TorchCPPKernelAdapter # noqa: F401
from .ctypes import CtypesKernelAdapter # noqa: F401
@@ -2,16 +2,23 @@
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
from typing import Any, List
from abc import ABC, abstractmethod
from typing import Any, List, Callable, Optional
from tvm.relay import TensorType
class BaseKernelAdapter(object):
class BaseKernelAdapter(ABC):
func: Optional[Callable] = None
def __init__(self, mod, params: List[TensorType], result_idx: List[int]) -> None:
self.mod = mod
self.params = params
self.result_idx = self._legalize_result_idx(result_idx)
self._post_init()
def _legalize_result_idx(self, result_idx: List[int]) -> List[int]:
params = self.params
# result_idx is a list of indices of the output tensors
if result_idx is None:
result_idx = []
@@ -25,15 +32,17 @@ class BaseKernelAdapter(object):
elif not isinstance(result_idx, list):
raise ValueError("result_idx should be a list of integers")
self.result_idx = result_idx
self.func = self._convert_torch_func()
return result_idx
@abstractmethod
def _convert_torch_func(self) -> callable:
raise NotImplementedError
pass
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.func(*args, **kwds)
def get_kernel_source(self) -> str:
return self.mod.imported_modules[0].get_source()
def _post_init(self):
self.func = self._convert_torch_func()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
from typing import List
from .base import BaseKernelAdapter
from tvm.relay import TensorType
class CtypesKernelAdapter(BaseKernelAdapter):
target = "cuda"
prim_func = None
def __init__(self,
mod,
params: List[TensorType],
result_idx: List[int],
target,
prim_func,
verbose: bool = False):
self.target = target
self.prim_func = prim_func
self.verbose = verbose
super().__init__(mod, params, result_idx)
raise NotImplementedError("CtypesKernelAdapter is not implemented yet.")
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .adapter import CtypesKernelAdapter # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
import torch
from ..base import BaseKernelAdapter
import ctypes
from typing import List, Optional, Union, Callable
from tilelang import tvm as tvm
from tvm.target import Target
from tvm.relay import TensorType
from tvm import tir
from .wrapper import TLWrapper
from .libgen import LibraryGenerator
from tilelang.utils.target import determine_target
class CtypesKernelAdapter(BaseKernelAdapter):
target = "cuda"
ir_module = None
is_dynamic: bool = False
lib: Optional[ctypes.CDLL] = None
def __init__(self,
rt_mod,
params: List[TensorType],
result_idx: List[int],
target,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
is_dynamic: bool = False,
verbose: bool = False):
self.mod = rt_mod
self.params = params
self.result_idx = self._legalize_result_idx(result_idx)
if isinstance(func_or_mod, tir.PrimFunc):
self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
else:
self.ir_module = func_or_mod
self.target = Target.canon_target(determine_target(target))
self.verbose = verbose
self.wrapper = TLWrapper(self.target)
self.lib_generator = LibraryGenerator(self.target)
self.wrapper.assign_optimized_module(self.ir_module)
wrapped_source = self.wrapper.wrap(self.get_kernel_source(), is_dynamic)
self.lib_generator.update_lib_code(wrapped_source)
self.lib_generator.compile_lib()
self.lib = self.lib_generator.load_lib()
self.lib.init()
self._post_init()
def _forward_from_prebuild_lib(self, *args, stream=0):
ctypes_args = [
ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args
]
ctypes_args.append(ctypes.c_void_p(stream))
self.lib.call(*ctypes_args)
def _warp_forward_from_prebuild_lib(self, *ins: List[torch.Tensor], stream=0):
if len(ins) + len(self.result_idx) != len(self.params):
raise ValueError(
f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs"
)
ins_idx = 0
args = []
# use the device of the first input tensor if available
device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
for i in range(len(self.params)):
if i in self.result_idx:
dtype = torch.__getattribute__(str(self.params[i].dtype))
shape = list(map(int, self.params[i].shape))
tensor = torch.empty(*shape, dtype=dtype, device=device)
else:
tensor = ins[ins_idx]
ins_idx += 1
args.append(tensor)
self._forward_from_prebuild_lib(*args)
if len(self.result_idx) == 1:
return args[self.result_idx[0]]
else:
return [args[i] for i in self.result_idx]
def _convert_torch_func(self) -> Callable:
return self._warp_forward_from_prebuild_lib
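
A hedged sketch of the call path the adapter sets up (constructor arguments follow the signature above; module and tensor names are illustrative):

# rt_mod/params come from tilelang.lower, prim_func is the original T.prim_func.
adapter = CtypesKernelAdapter(rt_mod, params=params, result_idx=[2],
                              target="cuda", func_or_mod=prim_func)
# adapter(A, B) allocates the output at result_idx, then forwards raw pointers:
#   lib.call(c_void_p(A.data_ptr()), c_void_p(B.data_ptr()),
#            c_void_p(C.data_ptr()), c_void_p(stream))
# which lands in the generated extern "C" call() wrapper produced by TLWrapper.
C = adapter(A, B)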
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
from .utils import is_cuda_target, is_hip_target
from tilelang import tvm as tvm
from tilelang.contrib.nvcc import get_target_compute_version
from tvm.target import Target
import ctypes
import os
import tempfile
import subprocess
import logging
from tilelang.env import TILELANG_TEMPLATE_PATH, CUTLASS_INCLUDE_DIR
logger = logging.getLogger(__name__)
class LibraryGenerator(object):
srcpath: Optional[str] = None
libpath: Optional[str] = None
lib_code: Optional[str] = None
def __init__(self, target: Target):
self.target = target
def update_lib_code(self, lib_code: str):
self.lib_code = lib_code
# Assume currently we only support CUDA compilation
def load_lib(self):
return ctypes.CDLL(self.libpath)
def compile_lib(self, timeout: float = None, with_tl: bool = True):
target = self.target
if is_cuda_target(target):
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
compute_version = "".join(get_target_compute_version(target).split("."))
libpath = src.name.replace(".cu", ".so")
command = [
"nvcc",
"-std=c++17",
"-Xcudafe",
"--diag_suppress=177",
"--compiler-options",
"'-fPIC'",
"-lineinfo",
"--shared",
src.name,
"-lcuda",
"-gencode",
f"arch=compute_{compute_version},code=sm_{compute_version}",
]
elif is_hip_target(target):
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
libpath = src.name.replace(".cpp", ".so")
command = [
"hipcc",
"-std=c++17",
"-fPIC",
"--shared",
src.name,
]
else:
raise ValueError(f"Unsupported target: {target}")
if with_tl:
command += [
"-I" + TILELANG_TEMPLATE_PATH,
"-I" + CUTLASS_INCLUDE_DIR,
]
command += ["-diag-suppress=20013"]
command += ["-o", libpath]
src.write(self.lib_code)
src.flush()
try:
ret = subprocess.run(command, timeout=timeout)
except subprocess.TimeoutExpired:
logger.warning(f"Compilation Timeout! {command}")
return None
if ret.returncode != 0:
logger.warning(f"Compilation Failed! {command}")
return None
self.srcpath = src.name
self.libpath = libpath
def remove_lib(self):
if self.libpath:
os.remove(self.libpath)
self.libpath = None
def get_source_path(self):
return self.srcpath
def get_lib_path(self):
return self.libpath
def set_lib_path(self, libpath):
self.libpath = libpath
def set_src_path(self, srcpath):
self.srcpath = srcpath
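
As a rough illustration (paths and compute capability assumed), the nvcc invocation compile_lib assembles for an sm_80 CUDA target looks approximately like:

# nvcc -std=c++17 -Xcudafe --diag_suppress=177 --compiler-options '-fPIC' \
#      -lineinfo --shared /tmp/tmpXXXX.cu -lcuda \
#      -gencode arch=compute_80,code=sm_80 \
#      -I<TILELANG_TEMPLATE_PATH> -I<CUTLASS_INCLUDE_DIR> \
#      -diag-suppress=20013 -o /tmp/tmpXXXX.so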
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
from typing import Union, Optional
from tilelang import tvm as tvm
from tvm import IRModule, tir
from tvm.target import Target
import tilelang.transform
from tilelang.engine.lower import (
is_device_call,
determine_target,
canon_target_host,
)
def match_global_kernel(source: str) -> int:
pattern = r"__global__\s+void\s+[__launch_bounds__\(\d+\)\s+]\w+"
matched = re.findall(pattern, source)
assert len(matched) >= 1 # may have statement before kernel
return source.index(matched[0])
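
A small runnable illustration of what match_global_kernel locates (the kernel signature here is made up):

src = 'extern "C" __global__ void __launch_bounds__(128) main_kernel(half_t* __restrict__ A);'
idx = match_global_kernel(src)           # index of the "__global__ void ..." match
declaration = src[idx:].split(";")[0]    # sliced the same way update_lib_code does below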
def is_cuda_target(target: Target) -> bool:
return target.kind.name == "cuda"
def is_hip_target(target: Target) -> bool:
return target.kind.name == "hip"
def get_annotated_device_mod(
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
target: Union[str, Target] = "auto",
target_host: Optional[Union[str, Target]] = None,
) -> "IRModule":
mod = func_or_mod
if isinstance(func_or_mod, tir.PrimFunc):
func = func_or_mod
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
if isinstance(target, str):
target = determine_target(target)
target_host = canon_target_host(target, target_host)
target_host = tvm.target.Target.canon_target(target_host)
target = tvm.target.Target(target, target_host)
mod = tir.transform.BindTarget(target)(mod)
mod = tilelang.transform.FrontendLegalize()(mod)
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.LayoutInference()(mod)
mod = tilelang.transform.LowerTileOp()(mod)
mod = tir.transform.Simplify()(mod)
if target.arch == "sm_90":
mod = tilelang.transform.WarpSpecializedPipeline()(mod)
else:
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.FlattenBuffer()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.VectorizeLoop()(mod)
mod = tir.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.RemoveNoOp()(mod)
mod = tir.transform.RewriteUnsafeSelect()(mod)
mod = tir.transform.HoistIfThenElse()(mod)
mod = tir.transform.VerifyMemory()(mod)
mod = tir.transform.AnnotateEntryFunc()(mod)
mod = tir.transform.ThreadSync("shared")(mod)
# TODO(lei): This is a hack to make sure the
# thread-level allreduce pass can be applied in TL.
# Because TL only uses one thread dimension, the var
# binding information is lost during lowering by the
# Legalization and Simplify passes.
# A better fix would be to create the var properly
# instead of running LowerThreadAllreduce before
# Legalization.
mod = tir.transform.LowerThreadAllreduce()(mod)
mod = tir.transform.ThreadSync("shared.dyn")(mod)
mod = tilelang.transform.LowerHopperIntrin()(mod)
mod = tir.transform.InjectPTXAsyncCopy()(mod)
mod = tir.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tir.transform.MergeSharedMemoryAllocations()(mod)
mod = tir.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
device_mod = tir.transform.Filter(is_device_call)(mod)
return device_mod
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from abc import ABC, abstractmethod
from tilelang import tvm as tvm
from typing import Optional, List, Dict, Union
from tvm import IRModule
from tvm.target import Target
from .utils import match_global_kernel, is_cuda_target, is_hip_target, get_annotated_device_mod
import re
import logging
PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """
cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {});
"""
PREDEF_INIT_FUNC = """
extern "C" void init() {{
{}
}}
"""
PREDEF_HOST_FUNC = """
extern "C" void call({}) {{
{}
}}
"""
class BaseWrapper(ABC):
@abstractmethod
def wrap(self, *args, **kwargs):
raise NotImplementedError
logger = logging.getLogger(__name__)
class TLCUDASourceWrapper(object):
_TYPE_MAP = {
"float32": "float",
"float16": "half_t",
"bfloat16": "bfloat16_t",
"e4m3_float8": "__nv_fp8_e4m3",
"e5m2_float8": "__nv_fp8_e5m2",
"float64": "double",
"int64": "int64_t",
"int32": "int",
"uint32": "unsigned int",
"bool": "int8_t",
"int8": "int8_t",
"uint8": "uint8_t",
"int16": "int16_t",
"uchar": "uint8_t",
}
backend = "tl"
def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
self.mod = scheduled_ir_module
self.target = target
self.source = source
self.function_name: Optional[str] = None
self.dynamic_smem_buf: Optional[int] = None
self.block_info: Union[List[int], Dict] = [1, 1, 1]
self.grid_info: Union[List[int], Dict] = [1, 1, 1]
self.parse_source_information()
self.srcpath: Optional[str] = None
self.libpath: Optional[str] = None
self.lib_code: Optional[str] = self.update_lib_code(source)
def parse_source_information(self):
device_mod = get_annotated_device_mod(self.mod, self.target)
assert (len(device_mod.functions) == 1
), "Only support one function in the module for static shape kernel."
for g_var, func in device_mod.functions.items():
self.function_name = g_var.name_hint
attrs = func.attrs
if "dyn_shared_memory_buf" in attrs:
self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"])
if "thread_extent" in attrs:
thread_extent = attrs["thread_extent"]
for tag, extent in thread_extent.items():
if "threadIdx" in tag:
self.block_info["xyz".index(tag[-1])] = extent
elif "blockIdx" in tag:
self.grid_info["xyz".index(tag[-1])] = extent
def get_dynamic_symbolic_set(self, prim_func):
# Determine the set of dynamic symbols used in the function
dynamic_symbolic_set = set()
for param in prim_func.params:
buffer = prim_func.buffer_map[param]
for dim in buffer.shape:
if isinstance(dim, tvm.tir.Var):
dynamic_symbolic_set.add(dim.name)
return dynamic_symbolic_set
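
A hedged illustration (buffer shape assumed) of what this collects:

# For a prim_func with a dynamic leading dimension, e.g.
#     A: T.Buffer((m, 1024), "float16")   # m is a tvm.tir.Var
# get_dynamic_symbolic_set(prim_func) returns {"m"}, and "int m" is later
# appended to the generated call() wrapper's argument list.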
def get_cuda_init_func(self):
# Initialize an empty string for the CUDA function call
call_str = """"""
# If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call
if self.dynamic_smem_buf is not None:
call_str = (
PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name,
self.dynamic_smem_buf))
# Format the initialization function using the call_str
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs
def update_lib_code(self, code: str):
# Update the library code with the given code string
self.lib_code = code
# Find the index of the global kernel function in the code
index = match_global_kernel(code)
# Extract the declaration of the function starting from the found index
declaration = code[index:].split(";")[0]
function_name = self.function_name
# Get the CUDA initialization function
init_func = self.get_cuda_init_func()
# Locate the opening brace of the function to insert arguments
index = code.index("{", index)
function_args = []
# Populate the function arguments from the primary function's parameters and buffers
for param in self.prim_func.params:
buffer = self.prim_func.buffer_map[param]
function_args.append({
"name": buffer.name,
"type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
})
dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
# Add dynamic symbolic parameters as integers to the function arguments
for dyn_sym in dynamic_symbolic_set:
function_args.append({"name": dyn_sym, "type": "int"})
function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
# Format the function arguments for declaration
def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
def func_call_args(s, function_args):
# Extract the function call arguments matching the function definition
pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
matches = re.findall(pattern, s)
call_args = []
for match in matches:
for arg in function_args:
if arg["name"] == match:
call_args.append(match)
return call_args
call_args = ", ".join(func_call_args(declaration, function_args))
block_info, grid_info = self.block_info, self.grid_info
def legalize_c(p):
# Convert TIR expressions to legal C expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
# Replace Python's floor division operator with C's division operator
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p).replace("//", "/")
# Prepare the block and grid dimensions for the CUDA kernel launch
block_str = "dim3({}, {}, {})".format(
legalize_c(block_info[0]),
legalize_c(block_info[1]),
legalize_c(block_info[2]),
)
grid_str = "dim3({}, {}, {})".format(
legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
# Determine the shared memory size, defaulting to 0 if not specified
smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf
# Format the CUDA kernel launch string
if len(dynamic_symbolic_set) != 0:
call_str = "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0])
else:
call_str = ""
call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str,
smem_str, call_args)
# Create the host function wrapper for the CUDA kernel
host_func = PREDEF_HOST_FUNC.format(def_args, call_str)
# Combine the source, initialization function, and host function to form the complete library code
lib_code = self.source + init_func + host_func
return lib_code
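
Putting the pieces together, a hedged example of the host glue update_lib_code appends for the static GEMM kernel in this commit's test (names, launch dimensions, and shared-memory size are illustrative; the real values come from the parsed thread_extent and dyn_shared_memory_buf attributes):

# extern "C" void call(half_t* __restrict__ A, half_t* __restrict__ B,
#                      half_t* __restrict__ C, cudaStream_t stream=cudaStreamDefault) {
#     main_kernel<<<dim3(4, 4, 1), dim3(128, 1, 1), smem_bytes, stream>>>(A, B, C);
# }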
@property
def prim_func(self):
if len(self.mod.get_global_vars()) == 1:
return self.mod[self.mod.get_global_vars()[0]]
elif "main" in self.mod:
return self.mod["main"]
else:
for _, function in self.mod.functions_items():
attr = function.attrs
if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
return function
raise ValueError("Cannot find primary function in the module.")
class TLCUDASourceWrapperWithDynamic(TLCUDASourceWrapper):
def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
super().__init__(scheduled_ir_module, source, target)
def get_cuda_init_func(self):
# Initialize an empty string to accumulate CUDA function calls for setting dynamic shared memory
call_str = """"""
# Iterate over functions and their dynamic shared memory requirements
for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items():
if dynamic_smem_buf is not None:
# Format the cudaFuncSetAttribute call for dynamic shared memory
call_str += PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(
function_name, dynamic_smem_buf)
# Define the init function that will set the attributes for each kernel
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs
def create_dispatch_func(self, code, function_informations):
# Extract the set of dynamic symbolic names used in the primary function
dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
# Find the location of the global kernel function in the code
index = match_global_kernel(code)
# Analyze the function declaration to prepare for argument extraction
dummy_declaration = code[index:].split(";")[0]
function_name = self.function_name
# Identify the start of the function body to insert arguments
index = code.index("{", index)
function_args = []
# Collect function arguments based on primary function's parameters and buffer mappings
for param in self.prim_func.params:
buffer = self.prim_func.buffer_map[param]
function_args.append({
"name": buffer.name,
"type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
})
# Add dynamic symbols as integer arguments
for dyn_sym in dynamic_symbolic_set:
function_args.append({"name": dyn_sym, "type": "int"})
function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
# Format the argument definitions for function declaration
def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
def func_call_args(s: str, function_args):
# Extract and clean the function call arguments to match the declaration
pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
matches = re.findall(pattern, s)
call_args = []
for match in matches:
match = re.sub(r"\d+", "", match) # Remove numbers
match = re.sub(r"_", "", match) # Remove underscores
for arg in function_args:
if arg["name"] == match:
call_args.append(match)
return call_args
call_args = ", ".join(func_call_args(dummy_declaration, function_args))
def legalize_c(p):
# Convert TIR expressions to legal C expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
# Replace Python's floor division operator with C's division operator
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p).replace("//", "/")
last_range = 0
num_items = len(function_informations)
_call_str = """"""
for last_range, (function_name, info) in enumerate(function_informations.items()):
# Prepare block and grid configurations for kernel launches
block_info, grid_info = info["block_info"], info["grid_info"]
block_str = "dim3({}, {}, {})".format(
legalize_c(block_info[0]),
legalize_c(block_info[1]),
legalize_c(block_info[2]),
)
grid_str = "dim3({}, {}, {})".format(
legalize_c(grid_info[0]),
legalize_c(grid_info[1]),
legalize_c(grid_info[2]),
)
# Handle dynamic shared memory specification
smem_str = (0 if info["dynamic_smem_buf"] is None else info["dynamic_smem_buf"])
opt_shapes = info["opt_shapes"]
# Generate conditional kernel launch code based on dynamic symbolic ranges
(symbolic,) = list(dynamic_symbolic_set)
range_str = opt_shapes[symbolic]
if last_range == 0:
call_str = " if ({} == 0) return; \n".format(symbolic,)
call_str += " if ({} <= {}) {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format(
symbolic,
range_str,
function_name,
grid_str,
block_str,
smem_str,
call_args,
)
else:
call_str = " else if ({} <= {}) {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format(
symbolic,
range_str,
function_name,
grid_str,
block_str,
smem_str,
call_args,
)
if last_range == num_items - 1:
call_str += " else {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format(
function_name, grid_str, block_str, smem_str, call_args)
_call_str += call_str
# Wrap the kernel dispatch logic in an external C function
host_func = PREDEF_HOST_FUNC.format(def_args, _call_str)
return host_func
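
A hedged sketch of the dispatch body this generates when a single dynamic symbol (say m) selects among specialized kernels (kernel names and ranges are illustrative):

# Inside the generated extern "C" void call(...):
#     if (m == 0) return;
#     if (m <= 64) { main_kernel_m_64<<<grid_64, block_64, smem_64, stream>>>(A, B, C, m); }
#     else if (m <= 128) { main_kernel_m_128<<<grid_128, block_128, smem_128, stream>>>(A, B, C, m); }
#     else { main_kernel_m_128<<<grid_128, block_128, smem_128, stream>>>(A, B, C, m); }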
def parse_source_information(self):
# Parse device module to extract execution configurations for each function
device_mod = get_annotated_device_mod(self.mod, self.target, backend=self.backend)
block_info_map = {}
grid_info_map = {}
dynamic_smem_buf_map = {}
for g_var, func in device_mod.functions.items():
# Default block and grid configurations
block_info = [1, 1, 1]
grid_info = [1, 1, 1]
function_name = g_var.name_hint
attrs = func.attrs
dynamic_smem_buf = None
if "dyn_shared_memory_buf" in attrs:
dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"])
if "thread_extent" in attrs:
# Extract block and grid sizes from thread extents
thread_extent = attrs["thread_extent"]
for tag, extent in thread_extent.items():
if "threadIdx" in tag:
block_info["xyz".index(tag[-1])] = extent
elif "blockIdx" in tag:
grid_info["xyz".index(tag[-1])] = extent
# Map the extracted configurations to each function
block_info_map[function_name] = block_info
grid_info_map[function_name] = grid_info
dynamic_smem_buf_map[function_name] = dynamic_smem_buf
# Store the mappings for use in code generation
self.block_info = block_info_map
self.grid_info = grid_info_map
self.dynamic_smem_buf = dynamic_smem_buf_map
def update_lib_code(self, code: str):
# Organize function information for code generation
function_informations = {}
for g_var, func in self.mod.functions.items():
function_name = g_var.name_hint
# Do not update function with dispatch host function
if (function_name not in self.block_info) or (function_name not in self.grid_info):
continue
attrs = func.attrs
assert "opt_shapes" in attrs
opt_shapes = attrs["opt_shapes"]
function_informations[function_name] = {
"function_name": function_name,
"opt_shapes": opt_shapes,
"block_info": self.block_info[function_name],
"grid_info": self.grid_info[function_name],
"dynamic_smem_buf": self.dynamic_smem_buf[function_name],
}
def compare_map_objects(map_obj):
comparable_representation = list(map_obj.values())
return comparable_representation
function_informations = dict(
sorted(
function_informations.items(),
key=lambda item: compare_map_objects(item[1]["opt_shapes"]),
))
self.lib_code = code
# Generate the initialization and dispatch functions
init_func = self.get_cuda_init_func()
host_func = self.create_dispatch_func(code, function_informations)
# Concatenate source code with generated code segments
lib_code = self.source + init_func + host_func
return lib_code
class TLHIPSourceWrapper(TLCUDASourceWrapper):
def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
super().__init__(scheduled_ir_module, source, target)
def get_hip_init_func(self):
# Initialize an empty string for the CUDA function call
call_str = """"""
# If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call
if self.dynamic_smem_buf is not None:
call_str = PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name,
self.dynamic_smem_buf)
# Format the initialization function using the call_str
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs
def get_stream_type(self, function_args):
function_args.append({"name": "stream=hipStreamDefault", "type": "hipStream_t"},)
class TLWrapper(BaseWrapper):
def __init__(self, target: Target):
super().__init__()
self.scheduled_ir_module = None
self.target = target
self.lib = None
def assign_optimized_module(self, scheduled_ir_module: IRModule):
self.scheduled_ir_module = scheduled_ir_module
# Get Scheduled Rt Module and return source to be compiled
def wrap(self, c_source: str, is_dynamic: bool = False):
assert self.scheduled_ir_module is not None, "Please assign optimized module first."
if is_cuda_target(self.target):
wrapper_class = (
TLCUDASourceWrapper if not is_dynamic else TLCUDASourceWrapperWithDynamic)
elif is_hip_target(self.target):
wrapper_class = TLHIPSourceWrapper
else:
raise ValueError(f"Unsupported platform: {self.arch.platform}")
wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.target)
return wrapper.lib_code
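
A brief end-to-end sketch of the wrapping flow, mirroring how CtypesKernelAdapter drives it earlier in this diff (variable names are illustrative; LibraryGenerator lives in the adjacent libgen module):

wrapper = TLWrapper(target)                   # target: canonical tvm Target
wrapper.assign_optimized_module(ir_module)    # scheduled IRModule from tilelang.lower
wrapped_source = wrapper.wrap(kernel_source)  # device source + init()/call() host glue
lib_gen = LibraryGenerator(target)
lib_gen.update_lib_code(wrapped_source)
lib_gen.compile_lib()                         # nvcc/hipcc -> temporary shared library
lib = lib_gen.load_lib()
lib.init()                                    # apply dynamic shared-memory attributes, if any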
@@ -7,7 +7,7 @@ import tilelang
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tilelang.jit.adapter import TorchCPPKernelAdapter, TorchDLPackKernelAdapter, BaseKernelAdapter
from tilelang.jit.adapter import TorchCPPKernelAdapter, TorchDLPackKernelAdapter, BaseKernelAdapter, CtypesKernelAdapter
from tilelang.utils.target import determine_target, AVALIABLE_TARGETS
from tilelang.profiler import Profiler, TensorSupplyType
@@ -36,6 +36,7 @@ class JITKernel(object):
out_idx: Union[List[int], int] = None,
execution_backend: Literal["dlpack", "torch_cpp", "ctypes"] = "dlpack",
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
verbose: bool = False,
):
"""
@@ -51,6 +52,8 @@
Execution backend to use for kernel execution (default: "dlpack").
target : Union[str, Target], optional
Compilation target, either as a string or a TVM Target object (default: "auto").
target_host : Union[str, Target], optional
Target host for cross-compilation (default: None).
verbose : bool, optional
Whether to enable verbose output (default: False).
"""
@@ -58,6 +61,7 @@
self.out_idx = out_idx
self.execution_backend = execution_backend
self.target = target
self.target_host = target_host
self.verbose = verbose
# If the target is specified as a string, validate it and convert it to a TVM Target.
@@ -113,12 +117,13 @@ class JITKernel(object):
"""
verbose = self.verbose
target = self.target
target_host = self.target_host
out_idx = self.out_idx
execution_backend = self.execution_backend
# Compile the function with TVM, optimizing with shared memory lowering.
with tvm.transform.PassContext(opt_level=3):
rt_mod, params = tilelang.lower(tilelang_func, target=target)
rt_mod, params = tilelang.lower(tilelang_func, target=target, target_host=target_host)
# Store the runtime module and parameters for later use.
self.rt_module = rt_mod
@@ -140,8 +145,15 @@
)
raise NotImplementedError("Torch CPP backend is not fully implemented.")
elif execution_backend == "ctypes":
# CTYPES backend (not implemented yet).
raise NotImplementedError("CTypes backend is not implemented.")
# CTYPES backend (not fully tested yet).
adapter = CtypesKernelAdapter(
rt_mod,
params=params,
result_idx=out_idx,
target=target,
func_or_mod=tilelang_func,
verbose=verbose,
)
else:
# Handle invalid backend.
raise ValueError(f"Invalid execution backend: {execution_backend}")
@@ -195,5 +207,11 @@ class JITKernel(object):
"""
return self.rt_module.imported_modules[0].get_source()
def get_host_source(self) -> str:
"""
Returns the source code of the host function.
"""
return self.rt_module.get_source()
def run_once(self, func: Optional[Callable] = None) -> None:
return self.get_profiler().run_once(func)