"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "fef9c0d9592bc101b7326589c5ba47d74984a448"
Commit 7c817d51 authored by Lei Wang, committed by GitHub

[Feature] Add CTypes JIT kernel support (#100)

* [Feature] Add CTypes JIT kernel support for dynamic shapes and multi-stream execution

- Enhance CtypesKernelAdapter to handle dynamic symbolic shapes
- Add support for multi-stream kernel execution in CTypes backend
- Implement dynamic shape handling in test_tilelang_jit_gemm_ctypes.py
- Add symbolic shape utility function in tilelang.language
- Update profiler to improve flexibility in benchmark selection

* Remove redundant thread binding in GEMM kernel implementations

- Remove unnecessary `thread_binding` line in GEMM kernel functions
- Clean up code in `examples/gemm/README.md` and `testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py`
- Enhance code readability by removing redundant thread binding annotation

* Fix indentation in int4 GEMM kernel test file

- Correct indentation for function calls in `test_tilelang_kernel_int4_gemm_mma.py`
- Remove extra indentation in `mma_emitter.ldmatrix_a()` and `mma_emitter.ldmatrix_b()` calls
- Improve code formatting for better readability
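
For reference, a minimal usage sketch of the new backend. This is hypothetical driver code, not part of the change itself: it assumes the `matmul` kernel builder defined in `test_tilelang_jit_gemm_ctypes.py` (shown in the diff below), and the shapes and tile sizes are illustrative.

import torch
import tilelang
import tilelang.language as T

# GEMM with a symbolic (runtime-dynamic) M dimension, compiled via the ctypes backend.
program = matmul(
    T.symbolic("m"), 1024, 768,       # M (dynamic), N, K
    128, 256, 32,                     # block_M, block_N, block_K
    False, False,                     # trans_A, trans_B
    "float16", "float16", "float16",  # in_dtype, out_dtype, dtypeAccum
    2, 128)                           # num_stages, num_threads
kernel = tilelang.JITKernel(program, execution_backend="ctypes")

a = torch.randn(512, 768, dtype=torch.float16).cuda()
b = torch.randn(768, 1024, dtype=torch.float16).cuda()
c = torch.empty(512, 1024, dtype=torch.float16).cuda()

# Launch on a user-managed CUDA stream; the adapter resolves the dynamic M from a.shape.
with torch.cuda.stream(torch.cuda.Stream()):
    kernel(a, b, c)

# Benchmark either the Python callable (torch path) or the underlying TVM module.
profiler = kernel.get_profiler()
print(profiler.do_bench(func=kernel, profiler="torch"))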
parent 7cd6b3cd
@@ -339,8 +339,6 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
thread_binding = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
@@ -2,6 +2,7 @@
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.language as T
import tilelang.testing
import tilelang
import torch
@@ -27,8 +28,6 @@ def matmul(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
@@ -235,5 +234,171 @@ def test_gemm_jit_kernel():
)
def run_ctypes_kernel_do_bench(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
profiler = matmul_kernel.get_profiler()
ctypes_latency = profiler.do_bench(func=matmul_kernel, profiler="torch")
print(f"Ctypes Latency: {ctypes_latency} ms")
assert ctypes_latency is not None
tvm_latency = profiler.do_bench()
print(f"TVM Latency: {tvm_latency} ms")
assert tvm_latency is not None
def test_ctypes_kernel_do_bench():
run_ctypes_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
def run_ctypes_kernel_multi_stream(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=torch.__getattribute__(out_dtype)).cuda()
num_streams = 4
for _ in range(num_streams):
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
matmul_kernel(tensor_a, tensor_b, tensor_c)
def test_ctypes_kernel_multi_stream():
run_ctypes_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16",
128, 256, 32, 2)
def run_ctypes_dynamic_shape(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
if isinstance(M, T.Var):
M = 1024
if isinstance(N, T.Var):
N = 1024
if isinstance(K, T.Var):
K = 768
tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=torch.__getattribute__(out_dtype)).cuda()
matmul_kernel(tensor_a, tensor_b, tensor_c)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float))
tilelang.testing.torch_assert_close(
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_ctypes_dynamic_shape():
run_ctypes_dynamic_shape(
T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_ctypes_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
run_ctypes_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), T.symbolic("k"), False, False, "float16", "float16",
"float16", 128, 256, 32, 2)
if __name__ == "__main__":
tilelang.testing.main()
@@ -109,8 +109,6 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
thread_binding = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
@@ -294,8 +292,6 @@ def tl_matmul_weight_only_transform(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
thread_binding = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
@@ -5,7 +5,7 @@
import torch
from ..base import BaseKernelAdapter
import ctypes
from typing import List, Optional, Union, Callable
from typing import List, Optional, Union, Callable, Dict, Tuple
from tilelang import tvm as tvm
from tvm.target import Target
from tvm.relay import TensorType
@@ -13,14 +13,25 @@ from tvm import tir
from .wrapper import TLWrapper
from .libgen import LibraryGenerator
from tilelang.utils.target import determine_target
from tilelang.utils.language import retrieve_func_from_module
class CtypesKernelAdapter(BaseKernelAdapter):
"""Adapter class that converts TVM/TIR functions to callable CUDA kernels using ctypes.
This adapter handles:
1. Converting TIR functions to compiled CUDA libraries
2. Managing dynamic shapes in tensor operations
3. Wrapping C++ kernels for Python/PyTorch usage
"""
# Class attributes to store compiled kernel information
target = "cuda"
ir_module = None
is_dynamic: bool = False
lib: Optional[ctypes.CDLL] = None
lib: Optional[ctypes.CDLL] = None # Compiled library handle
wrapped_source: Optional[str] = None # Generated C++ wrapper code
# Maps symbolic variables to their corresponding buffer and shape indices
dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None
def __init__(self,
rt_mod,
@@ -28,9 +39,17 @@ class CtypesKernelAdapter(BaseKernelAdapter):
result_idx: List[int],
target,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
is_dynamic: bool = False,
verbose: bool = False):
"""Initialize the adapter with the given TIR function or module.
Args:
rt_mod: Runtime module
params: List of tensor types for inputs/outputs
result_idx: Indices of output tensors
target: Target platform (e.g., 'cuda')
func_or_mod: TIR function or module to be compiled
verbose: Enable verbose logging
"""
self.mod = rt_mod
self.params = params
self.result_idx = self._legalize_result_idx(result_idx)
@@ -40,29 +59,69 @@ class CtypesKernelAdapter(BaseKernelAdapter):
else:
self.ir_module = func_or_mod
self.dynamic_symbolic_map = self._process_dynamic_symbolic()
self.target = Target.canon_target(determine_target(target))
self.verbose = verbose
self.wrapper = TLWrapper(self.target)
self.lib_generator = LibraryGenerator(self.target)
self.wrapper.assign_optimized_module(self.ir_module)
wrapped_source = self.wrapper.wrap(self.get_kernel_source(), is_dynamic)
self.wrapped_source = self.wrapper.wrap(self.get_kernel_source())
self.lib_generator.update_lib_code(wrapped_source)
self.lib_generator.update_lib_code(self.wrapped_source)
self.lib_generator.compile_lib()
self.lib = self.lib_generator.load_lib()
self.lib.init()
self._post_init()
def _forward_from_prebuild_lib(self, *args, stream=0):
def _process_dynamic_symbolic(self):
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
for runtime shape resolution.
"""
func = self.prim_func
params = func.params
buffer_map = func.buffer_map
dynamic_symbolic_map = {}
for i, param in enumerate(params):
buffer = buffer_map[param]
for j, shape in enumerate(buffer.shape):
if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map):
dynamic_symbolic_map[shape] = (i, j)
return dynamic_symbolic_map
def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
"""Low-level function to call the compiled CUDA kernel.
Converts PyTorch tensor pointers to C void pointers for ctypes interface.
"""
ctypes_args = [
ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args
]
ctypes_args.append(ctypes.c_void_p(stream))
self.lib.call(*ctypes_args)
def _warp_forward_from_prebuild_lib(self, *ins: List[torch.Tensor], stream=0):
def _warp_forward_from_prebuild_lib(self,
*ins: List[torch.Tensor],
stream: Optional[int] = None):
"""High-level wrapper for kernel execution.
Handles:
1. Input validation
2. Output tensor allocation
3. Dynamic shape resolution
4. CUDA stream management
Args:
ins: Input PyTorch tensors
stream: Optional CUDA stream for asynchronous execution
Returns:
Single tensor or list of tensors containing the kernel results
"""
if len(ins) + len(self.result_idx) != len(self.params):
raise ValueError(
f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs"
@@ -70,20 +129,28 @@ class CtypesKernelAdapter(BaseKernelAdapter):
ins_idx = 0
args = []
# use the device of the first input tensor if available
device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
# tensor pointers
for i in range(len(self.params)):
if i in self.result_idx:
dtype = torch.__getattribute__(str(self.params[i].dtype))
shape = list(map(int, self.params[i].shape))
# use the device of the first input tensor if available
device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
tensor = torch.empty(*shape, dtype=dtype, device=device)
else:
tensor = ins[ins_idx]
ins_idx += 1
args.append(tensor)
self._forward_from_prebuild_lib(*args)
# dynamic symbolics
for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
args.append(ins[buffer_idx].shape[shape_idx])
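        # At this point args holds the tensor arguments followed by the runtime values
        # of any dynamic dimensions, e.g. [A, B, C, m] for a dynamic-M GEMM (illustrative).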
        # if no stream was provided, fall back to the caller's current CUDA stream
        # so the kernel launches on it
if stream is None:
stream = torch.cuda.current_stream().cuda_stream
self._forward_from_prebuild_lib(*args, stream=stream)
if len(self.result_idx) == 1:
return args[self.result_idx[0]]
@@ -91,4 +158,30 @@ class CtypesKernelAdapter(BaseKernelAdapter):
return [args[i] for i in self.result_idx]
def _convert_torch_func(self) -> Callable:
"""Returns a PyTorch-compatible function wrapper for the kernel."""
return self._warp_forward_from_prebuild_lib
@property
def prim_func(self) -> tir.PrimFunc:
"""Returns the primary TIR function from the IR module."""
return retrieve_func_from_module(self.ir_module)
@property
def srcpath(self):
"""Returns the source path of the compiled library."""
return self.lib_generator.srcpath
@property
def libpath(self):
"""Returns the path to the compiled library."""
return self.lib_generator.libpath
@property
def lib_code(self):
"""Returns the code of the compiled library."""
return self.lib_generator.lib_code
@property
def is_dynamic(self):
"""Indicates whether the kernel handles dynamic shapes."""
return (self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0)
@@ -89,12 +89,12 @@ class TLCUDASourceWrapper(object):
def get_dynamic_symbolic_set(self, prim_func):
# Determine the set of dynamic symbols used in the function
dynamic_symbolic_set = set()
dynamic_symbolic_set: List[str] = []
for param in prim_func.params:
buffer = prim_func.buffer_map[param]
for dim in buffer.shape:
if isinstance(dim, tvm.tir.Var):
dynamic_symbolic_set.add(dim.name)
if isinstance(dim, tvm.tir.Var) and (dim.name not in dynamic_symbolic_set):
dynamic_symbolic_set.append(dim.name)
return dynamic_symbolic_set
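    # Example (illustrative): for buffers A: (m, k), B: (k, n) this returns
    # ["m", "k", "n"], preserving first-appearance order (previously an unordered set).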
def get_cuda_init_func(self):
@@ -201,203 +201,6 @@ class TLCUDASourceWrapper(object):
raise ValueError("Cannot find primary function in the module.")
class TLCUDASourceWrapperWithDynamic(TLCUDASourceWrapper):
def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
super().__init__(scheduled_ir_module, source, target)
def get_cuda_init_func(self):
# Initialize an empty string to accumulate CUDA function calls for setting dynamic shared memory
call_str = """"""
# Iterate over functions and their dynamic shared memory requirements
for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items():
if dynamic_smem_buf is not None:
# Format the cudaFuncSetAttribute call for dynamic shared memory
call_str += PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(
function_name, dynamic_smem_buf)
# Define the init function that will set the attributes for each kernel
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs
def create_dispatch_func(self, code, function_informations):
# Extract the set of dynamic symbolic names used in the primary function
dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
# Find the location of the global kernel function in the code
index = match_global_kernel(code)
# Analyze the function declaration to prepare for argument extraction
dummy_declaration = code[index:].split(";")[0]
function_name = self.function_name
# Identify the start of the function body to insert arguments
index = code.index("{", index)
function_args = []
# Collect function arguments based on primary function's parameters and buffer mappings
for param in self.prim_func.params:
buffer = self.prim_func.buffer_map[param]
function_args.append({
"name": buffer.name,
"type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
})
# Add dynamic symbols as integer arguments
for dyn_sym in dynamic_symbolic_set:
function_args.append({"name": dyn_sym, "type": "int"})
function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
# Format the argument definitions for function declaration
def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
def func_call_args(s: str, function_args):
# Extract and clean the function call arguments to match the declaration
pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
matches = re.findall(pattern, s)
call_args = []
for match in matches:
match = re.sub(r"\d+", "", match) # Remove numbers
match = re.sub(r"_", "", match) # Remove underscores
for arg in function_args:
if arg["name"] == match:
call_args.append(match)
return call_args
call_args = ", ".join(func_call_args(dummy_declaration, function_args))
def legalize_c(p):
# Convert TIR expressions to legal C expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
# Replace Python's floor division operator with C's division operator
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p).replace("//", "/")
last_range = 0
num_items = len(function_informations)
_call_str = """"""
for last_range, (function_name, info) in enumerate(function_informations.items()):
# Prepare block and grid configurations for kernel launches
block_info, grid_info = info["block_info"], info["grid_info"]
block_str = "dim3({}, {}, {})".format(
legalize_c(block_info[0]),
legalize_c(block_info[1]),
legalize_c(block_info[2]),
)
grid_str = "dim3({}, {}, {})".format(
legalize_c(grid_info[0]),
legalize_c(grid_info[1]),
legalize_c(grid_info[2]),
)
# Handle dynamic shared memory specification
smem_str = (0 if info["dynamic_smem_buf"] is None else info["dynamic_smem_buf"])
opt_shapes = info["opt_shapes"]
# Generate conditional kernel launch code based on dynamic symbolic ranges
(symbolic,) = list(dynamic_symbolic_set)
range_str = opt_shapes[symbolic]
if last_range == 0:
call_str = " if ({} == 0) return; \n".format(symbolic,)
call_str += " if ({} <= {}) {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format(
symbolic,
range_str,
function_name,
grid_str,
block_str,
smem_str,
call_args,
)
else:
call_str = " else if ({} <= {}) {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format(
symbolic,
range_str,
function_name,
grid_str,
block_str,
smem_str,
call_args,
)
if last_range == num_items - 1:
call_str += " else {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format(
function_name, grid_str, block_str, smem_str, call_args)
_call_str += call_str
# Wrap the kernel dispatch logic in an external C function
host_func = PREDEF_HOST_FUNC.format(def_args, _call_str)
return host_func
def parse_source_information(self):
# Parse device module to extract execution configurations for each function
device_mod = get_annotated_device_mod(self.mod, self.target, backend=self.backend)
block_info_map = {}
grid_info_map = {}
dynamic_smem_buf_map = {}
for g_var, func in device_mod.functions.items():
# Default block and grid configurations
block_info = [1, 1, 1]
grid_info = [1, 1, 1]
function_name = g_var.name_hint
attrs = func.attrs
dynamic_smem_buf = None
if "dyn_shared_memory_buf" in attrs:
dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"])
if "thread_extent" in attrs:
# Extract block and grid sizes from thread extents
thread_extent = attrs["thread_extent"]
for tag, extent in thread_extent.items():
if "threadIdx" in tag:
block_info["xyz".index(tag[-1])] = extent
elif "blockIdx" in tag:
grid_info["xyz".index(tag[-1])] = extent
# Map the extracted configurations to each function
block_info_map[function_name] = block_info
grid_info_map[function_name] = grid_info
dynamic_smem_buf_map[function_name] = dynamic_smem_buf
# Store the mappings for use in code generation
self.block_info = block_info_map
self.grid_info = grid_info_map
self.dynamic_smem_buf = dynamic_smem_buf_map
def update_lib_code(self, code: str):
# Organize function information for code generation
function_informations = {}
for g_var, func in self.mod.functions.items():
function_name = g_var.name_hint
# Do not update function with dispatch host function
if (function_name not in self.block_info) or (function_name not in self.grid_info):
continue
attrs = func.attrs
assert "opt_shapes" in attrs
opt_shapes = attrs["opt_shapes"]
function_informations[function_name] = {
"function_name": function_name,
"opt_shapes": opt_shapes,
"block_info": self.block_info[function_name],
"grid_info": self.grid_info[function_name],
"dynamic_smem_buf": self.dynamic_smem_buf[function_name],
}
def compare_map_objects(map_obj):
comparable_representation = list(map_obj.values())
return comparable_representation
function_informations = dict(
sorted(
function_informations.items(),
key=lambda item: compare_map_objects(item[1]["opt_shapes"]),
))
self.lib_code = code
# Generate the initialization and dispatch functions
init_func = self.get_cuda_init_func()
host_func = self.create_dispatch_func(code, function_informations)
# Concatenate source code with generated code segments
lib_code = self.source + init_func + host_func
return lib_code
class TLHIPSourceWrapper(TLCUDASourceWrapper):
def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
@@ -430,11 +233,10 @@ class TLWrapper(BaseWrapper):
self.scheduled_ir_module = scheduled_ir_module
# Get Scheduled Rt Module and return source to be compiled
def wrap(self, c_source: str, is_dynamic: bool = False):
def wrap(self, c_source: str):
assert self.scheduled_ir_module is not None, "Please assign optimized module first."
if is_cuda_target(self.target):
wrapper_class = (
TLCUDASourceWrapper if not is_dynamic else TLCUDASourceWrapperWithDynamic)
wrapper_class = TLCUDASourceWrapper
elif is_hip_target(self.target):
wrapper_class = TLHIPSourceWrapper
else:
@@ -35,6 +35,10 @@ from .customize import (
from .builtin import * # noqa: F401
def symbolic(name: str, dtype: str = "int32"):
return tir.Var(name, dtype)
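# Example: M = T.symbolic("m") yields a tir.Var that can be used directly as a buffer
# dimension (as in the GEMM tests above), making the kernel shape-dynamic along that axis.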
def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
# If order is row, use rasterization2DRow, otherwise use rasterization2DColumn
# The panel size is the number of threads in a warp
@@ -93,6 +93,16 @@ class Profiler(TorchDLPackKernelAdapter):
func = self.__call__
return func(*ins)
def determine_profiler(self,
func: Optional[Callable] = None,
profiler: Literal["torch", "tvm", "auto"] = "auto"):
if profiler == "auto":
if func is None or isinstance(func, tvm.runtime.Module):
return "tvm"
else:
return "torch"
return profiler
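    # Behavior sketch: do_bench() with no func resolves to the "tvm" path and benchmarks
    # self.mod, while do_bench(func=jit_kernel) resolves to the "torch" path and
    # benchmarks the Python callable directly.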
def do_bench(
self,
func: Optional[Callable] = None,
@@ -103,11 +113,7 @@
profiler: Literal["torch", "tvm", "auto"] = "auto",
input_tensors: List[torch.Tensor] = None,
):
if func is None:
# set default value if not provided
func = self.mod
profiler = "tvm"
profiler = self.determine_profiler(func, profiler)
if profiler == "torch":
ins = self._get_inputs() if input_tensors is None else input_tensors
bench_func = partial(func, *ins)
@@ -119,6 +125,9 @@
_n_repeat=n_repeat,
)
elif profiler == "tvm":
if func is None:
func = self.mod
assert isinstance(func, tvm.runtime.Module), "func should be a TVM module"
ins = (self._get_inputs(with_output=True) if input_tensors is None else input_tensors)
target = "cuda"
@@ -133,25 +142,6 @@
tvm_inputs = [adapt_torch2tvm(inp) for inp in ins]
# Transform Latency to ms
return time_evaluator(*tvm_inputs).mean * 1e3
elif profiler == "auto":
# TODO(lei): select appropriate profiler based on the function
# class
ins = self._get_inputs()
bench_func = partial(func, *ins)
torch_res = do_bench(
bench_func,
warmup=warmup,
rep=rep,
_n_warmup=n_warmup,
_n_repeat=n_repeat,
)
ins = self._get_inputs(with_output=True)
time_evaluator = self.mod.time_evaluator(
self.mod.entry_name, tvm.cuda(0), number=rep, repeat=n_repeat)
tvm_inputs = [adapt_torch2tvm(inp) for inp in ins]
tvm_res = time_evaluator(*tvm_inputs).mean * 1e3
return min(torch_res, tvm_res)
else:
raise ValueError(f"Unknown profiler: {profiler}")
@@ -4,6 +4,8 @@
from tvm.tir import Buffer
from typing import List
from functools import reduce
from tvm import IRModule
from tvm.tir import PrimFunc
# Scope Checkers for TVM Buffers
# These utility functions check the memory scope of a given TVM buffer.
@@ -89,3 +91,26 @@ def array_reduce(array: List[int]) -> int:
int: The reduced integer.
"""
return reduce(lambda x, y: x * y, array)
def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
"""
Retrieve the single PrimFunc from an IRModule.
Args:
ir_module (IRModule): The TVM IRModule to extract the function from.
The module should contain exactly one global function.
Returns:
PrimFunc: The single function contained in the module.
Raises:
ValueError: If ir_module is not an IRModule.
AssertionError: If the module contains more than one global function.
"""
if not isinstance(ir_module, IRModule):
raise ValueError("Not supported type: ", type(ir_module))
assert len(ir_module.get_global_vars()) == 1, (
"The optimized module should only have one global variable for default schedule.")
func = list(ir_module.functions.values())[0]
return func