"...composable_kernel_onnxruntime.git" did not exist on "76f3131939fb6bd0ed34cfac3be3b92c672b49e6"
Commit eec47592 authored by Lei Wang, committed by LeiWang1999

[AutoTune] Support `with set_autotune_inputs` to set auto tuning input tensors (#632)

* [Refactor] Simplify and modularize autotuner implementation

- Removed unused imports and extensive code sections from the autotuner module to enhance readability and maintainability.
- Modularized the code by introducing new imports for autotuning and capturing functionalities, streamlining the overall structure.
- Improved logging setup and removed redundant timeout handling functions, focusing on core autotuning logic.
- Updated the AutoTuner class to better utilize the new modular structure, ensuring efficient performance during auto-tuning processes.

* [Refactor] Clean up and enhance capture and tuner modules

- Improved code readability by removing unnecessary blank lines and organizing imports in `capture.py` and `tuner.py`.
- Enhanced logging in the `AutoTuner` class to provide clearer warnings regarding the usage of `supply_prog` in the context of auto-tuning.
- Streamlined the `CaptureStack` class for better thread-local context management.

* lint fix

* [Refactor] Simplify configuration and autotuning logic in blocksparse GEMM example

- Updated `get_configs` function to reduce the number of configurations, enhancing performance and clarity.
- Removed the `get_best_config` function, integrating its logic directly into the `blocksparse_matmul` function with the `@autotune` decorator for streamlined autotuning.
- Adjusted the main function to directly utilize the autotuned kernel, simplifying the overall structure and improving readability.
- Deleted obsolete test file for autotuning decorator, cleaning up the codebase.

* [Refactor] Improve code formatting and readability in autotune test file

- Reformatted the `matmul` function and `get_configs` function for better readability by adjusting line breaks and indentation.
- Fixed a typo in the `enable_rasteration` parameter name to ensure consistency.
- Cleaned up unnecessary blank lines to enhance overall code clarity.

* Update example_blocksparse_gemm.py

* Update capture.py
parent b6fe9582
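In brief, the workflow this commit enables looks like the following condensed sketch, pieced together from the example and test diffs below. The matrix sizes, the config list, and the exact decorator stacking are illustrative assumptions rather than canonical usage; the real examples follow.

import itertools

import torch
import tilelang
import tilelang.language as T
from tilelang.autotuner import set_autotune_inputs


def get_configs():
    # Small illustrative search space; the examples below enumerate more values.
    iter_params = dict(block_M=[64, 128], block_N=[64, 128], block_K=[32], num_stages=[0, 1])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=0, thread_num=128):
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # One thread block computes a (block_M, block_N) tile of C = A @ B^T.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


M = N = K = 1024
a = torch.randn(M, K, dtype=torch.float16).cuda()
b = torch.randn(N, K, dtype=torch.float16).cuda()

# New in this commit: hand the autotuner its input tensors via a context
# manager instead of a custom supply_prog function.
with set_autotune_inputs([a, b]):
    kernel = matmul(M, N, K)

c = kernel(a, b)
torch.testing.assert_close(c, a @ b.T, rtol=1e-2, atol=1e-2)

Previously the inputs had to be provided through a custom supply function passed to set_profile_args, as in the removed get_best_config helper in the diff below.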
@@ -2,7 +2,6 @@ import argparse
import itertools
import tilelang
import tilelang.language as T
from tilelang.autotuner import AutoTuner
from tilelang.engine.param import KernelParam
from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType
import torch
@@ -34,7 +33,7 @@ print(f"Target Block Sparsity: {sparsity}")
print(f"Using Autotuner: {use_autotune}\n")
def get_configs(M, N, K):
def get_configs():
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
@@ -90,55 +89,7 @@ def supply_program(params: List[KernelParam]):
return input_tensors
def get_best_config(M, N, K):
# Define the kernel function to be tuned.
# Parameters like block_M, block_N, etc., are tuned by the AutoTuner.
def kernel(block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None):
return blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num,
enable_rasteration)
autotuner = AutoTuner.from_kernel(
kernel=kernel, configs=get_configs(M, N, K)
).set_compile_args(
out_idx=[-1], # Index of the output tensor
target="auto", # Automatically detect target
).set_profile_args(
# supply_type should not set here because we provide a custom supply
# function `supply_prog` and `supply_type` will be ignored.
# supply_prog: Provide the custom function to generate input tensors
# (A, B, BlockMask) for the kernel, allowing controlling sparsity via
# BlockMask generation.
supply_prog=supply_program,
# ref_prog: Using dense matmul (A @ B) as a placeholder reference.
# The 'correct' block-sparse reference (`ref_program` above) requires
# block_M, block_N, block_K parameters. However, these parameters are
# part of the configuration being *tuned* by the AutoTuner and cannot
# be fixed inputs to a static `ref_prog` function signature.
# This dense matmul serves only as a performance baseline.
ref_prog=lambda A, B, BlockMask: A @ B,
# skip_check: Set to True because the provided `ref_prog` does not
# compute the correct result for the block-sparse kernel.
skip_check=True,
# cache_input_tensors: Set to False because the shape of the BlockMask tensor
# (dependent on block_M, block_N, block_K being tuned) changes between
# different configurations. Reusing cached tensors from a previous
# configuration would lead to shape mismatches.
cache_input_tensors=False,
)
# Run the tuning process
return autotuner.run(warmup=3, rep=20)
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(out_idx=[-1])
def blocksparse_matmul(M,
N,
@@ -192,22 +143,16 @@ def main():
# Run the autotuner to find the best kernel configuration and performance
# get_best_config is expected to return an object containing the compiled kernel,
# the best configuration found, latency, and reference latency.
result = get_best_config(M, N, K)
kernel = blocksparse_matmul(M, N, K)
# Extract results from the autotuner run
kernel = result.kernel
best_config = result.config
block_M = best_config[0]
block_N = best_config[1]
block_K = best_config[2]
best_latency = result.latency
ref_latency = result.ref_latency
best_config = kernel.config
best_latency = kernel.latency
block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[
"block_K"]
print(f"Best Config: {best_config}")
print(f"Block Dimensions (BM, BN, BK): ({block_M}, {block_N}, {block_K})")
print(f"Sparsity Ratio: {sparsity}")
print(f"Best Kernel Latency: {best_latency:.6f} ms")
print(f"Reference Latency: {ref_latency:.6f} ms")
else:
kernel = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K,
DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM,
......
import torch
import argparse
import itertools
import tilelang as tl
from tilelang.autotuner import *
import tilelang
import tilelang.language as T
from tilelang.autotuner import AutoTuner
from tilelang.carver.template import ConvTemplate
@@ -165,7 +164,7 @@ def get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller=False):
out_idx=[2],
target="auto",
).set_profile_args(
supply_type=tl.TensorSupplyType.Integer,
supply_type=tilelang.TensorSupplyType.Integer,
ref_prog=ref_prog,
skip_check=False,
)
@@ -299,9 +298,9 @@ def main(n: int = 128,
kernel = result.kernel
else:
config = get_heuristic_config()
kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2])
kernel = tilelang.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2])
profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
tilelang_latency = profiler.do_bench()
ref_latency = profiler.do_bench(ref_prog)
profiler.assert_allclose(ref_prog, atol=1e-2, rtol=1e-2)
......
import itertools
import logging
import tilelang.testing
import tilelang.language as T
from tilelang.autotuner import autotune
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def ref_program(A, B):
"""
A reference matrix multiplication program, used to compare performance.
Parameters
----------
A : numpy.ndarray
The matrix with shape (M, K).
B : numpy.ndarray
The matrix with shape (N, K).
Returns
-------
np.ndarray
The result of A @ B.T, shape (M, N).
"""
return A @ B.T
def get_configs(M, N, K, with_roller=False):
"""
Generate a list of configuration dictionaries that will be used for tuning.
Parameters
----------
with_roller : bool
Whether to enable bitblas roller to deduce search spaces
Returns
-------
list of dict
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization
arch = CUDA("cuda")
topk = 20
# Simple TIR Compute Expression
carve_template = MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
roller_hints = carve_template.recommend_hints(topk=topk)
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
for hint in roller_hints:
config = {}
block_m, block_n = hint.block
warp_m, warp_n = hint.warp
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = 0
config["thread_num"] = (block_m * block_n) // (warp_m * warp_n) * 32
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
configs.append(config)
for config in configs:
print(config)
else:
block_M = [64]
block_N = [64]
block_K = [32]
num_stages = [0, 1]
thread_num = [128]
enable_rasterization = [False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs
]
return configs
def matmul(M, N, K, with_roller):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
- B: (N, K)
- C: (M, N)
Parameters
----------
M : int
The dimension M of the matrix multiplication.
N : int
The dimension N of the matrix multiplication.
K : int
The dimension K of the matrix multiplication.
with_roller : bool
Whether to use the carver roller to deduce the candidate configurations.
Returns
-------
kernel
The autotuned, compiled kernel produced by calling the decorated ``kernel()``.
The best configuration and its measured latency are attached to the returned
object (e.g. ``kernel.config`` and ``kernel.latency``).
"""
@autotune(
configs=get_configs(M, N, K, with_roller),
warmup=3,
rep=20,
)
@tilelang.jit(out_idx=[-1],)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None,
):
"""
The actual kernel to compute C = A @ B^T.
Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
enable_rasteration : bool
Whether to enable rasterization (swizzling) optimization.
Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)
# Clear out the accumulation buffer
T.clear(C_local)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(
A[by * block_M, k * block_K],
A_shared,
)
# Load a sub-block of B from global memory into B_shared
T.copy(
B[bx * block_N, k * block_K],
B_shared,
)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C[by * block_M, bx * block_N])
return main
return kernel()
def test_autotune_get_configs():
get_configs(8192, 8192, 8192, with_roller=True)
get_configs(8192, 8192, 8192, with_roller=False)
def test_autotune_matmul():
matmul(8192, 8192, 8192, with_roller=True)
matmul(8192, 8192, 8192, with_roller=False)
if __name__ == "__main__":
tilelang.testing.main()
import itertools
import logging
import tilelang
import tilelang.testing
from tilelang.autotuner import set_autotune_inputs
import tilelang.language as T
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def ref_program(A, B):
"""
A reference matrix multiplication program, used to compare performance.
Parameters
----------
A : numpy.ndarray
The matrix with shape (M, K).
B : numpy.ndarray
The matrix with shape (N, K).
Returns
-------
np.ndarray
The result of A @ B.T, shape (M, N).
"""
return A @ B.T
def get_configs():
iter_params = dict(
block_M=[64],
block_N=[64],
block_K=[32],
num_stages=[0, 1],
thread_num=[128],
enable_rasteration=[False])  # key must match the kernel's parameter name
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1])
def matmul(M,
N,
K,
block_M=128,
block_N=128,
block_K=32,
num_stages=0,
thread_num=128,
enable_rasteration=False):
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)
# Clear out the accumulation buffer
T.clear(C_local)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(
A[by * block_M, k * block_K],
A_shared,
)
# Load a sub-block of B from global memory into B_shared
T.copy(
B[bx * block_N, k * block_K],
B_shared,
)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_autotune(M: int, N: int, K: int):
import torch
a = torch.randn(M, K, dtype=torch.float16).cuda()
b = torch.randn(N, K, dtype=torch.float16).cuda()
with set_autotune_inputs([a, b]):
kernel = matmul(M, N, K)
c = kernel(a, b)
ref_c = ref_program(a, b)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
def test_autotune_matmul():
run_autotune(8192, 8192, 8192)
if __name__ == "__main__":
tilelang.testing.main()
import threading
from typing import List, Any, Optional
# Use thread local to store the stack
# This is to avoid the cross-thread interference
_local = threading.local()
class CaptureStack:
"""
A simple stack implementation for capturing items in a thread-local context.
Used to manage a stack of items (e.g., input tensors) for auto-tuning capture.
"""
def __init__(self):
# Initialize an empty list to use as the stack
self.stack = []
def push(self, item):
"""
Push an item onto the top of the stack.
Args:
item: The item to be pushed onto the stack.
"""
self.stack.append(item)
def pop(self):
"""
Pop and return the top item from the stack.
Returns:
The item at the top of the stack.
Raises:
IndexError: If the stack is empty.
"""
return self.stack.pop()
def top(self):
"""
Return the item at the top of the stack without removing it.
Returns:
The item at the top of the stack.
Raises:
IndexError: If the stack is empty.
"""
return self.stack[-1]
def size(self):
"""
Return the number of items in the stack.
Returns:
int: The size of the stack.
"""
return len(self.stack)
def __len__(self):
"""
Return the number of items in the stack (len operator support).
Returns:
int: The size of the stack.
"""
return len(self.stack)
def __bool__(self):
"""
Return True if the stack is not empty, False otherwise.
Returns:
bool: Whether the stack contains any items.
"""
return bool(self.stack)
def _get_current_stack() -> CaptureStack:
if not hasattr(_local, "capture_stack"):
_local.capture_stack = CaptureStack()
return _local.capture_stack
class AutotuneInputsCapture:
__slots__ = ("tensors")
def __init__(self, tensors: List[Any]):
self.tensors = tensors
def __enter__(self) -> None:
_get_current_stack().push(self)
def __exit__(self, exc_type, exc_val, exc_tb):
_get_current_stack().pop()
def set_autotune_inputs(*args) -> AutotuneInputsCapture:
"""Set input tensors for auto-tuning.
This function creates a context manager for capturing input tensors
during the auto-tuning process. It supports both:
set_autotune_inputs(a, b, c)
set_autotune_inputs([a, b, c])
Args:
*args: Either a single list/tuple of tensors, or multiple tensor arguments.
Returns:
AutotuneInputsCapture: A context manager for auto-tuning inputs.
"""
if len(args) == 1 and isinstance(args[0], (list, tuple)):
tensors = list(args[0])
else:
tensors = list(args)
return AutotuneInputsCapture(tensors)
def get_autotune_inputs() -> Optional[List[Any]]:
"""
Get the current autotune inputs from the stack.
"""
stack = _get_current_stack()
return stack.top().tensors if stack else None
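For illustration, a minimal sketch of the capture helpers above, exercising both calling conventions. The import path tilelang.autotuner.capture is an assumption based on the capture.py filename in the commit message; the test above imports set_autotune_inputs from tilelang.autotuner.

import torch

# Assumed module path; get_autotune_inputs is defined next to set_autotune_inputs above.
from tilelang.autotuner.capture import set_autotune_inputs, get_autotune_inputs

# Plain host tensors are enough to demonstrate the capture stack; real tuning
# runs would pass CUDA tensors matching the kernel signature.
a = torch.randn(128, 64, dtype=torch.float16)
b = torch.randn(256, 64, dtype=torch.float16)

# Both calling conventions push the same list onto the thread-local stack.
with set_autotune_inputs(a, b):
    assert get_autotune_inputs() == [a, b]

with set_autotune_inputs([a, b]):
    assert get_autotune_inputs() == [a, b]

# Outside any active context the stack is empty and no inputs are reported.
assert get_autotune_inputs() is None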