"git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "9f7bac4c1c21d259c59f44114554256b39c3610b"
Commit 0430cfe7 authored by Chaofan Lin, committed by LeiWang1999

[Bugfix] Fix Benchmark/Example Code for Autotuning (#254)



* fix tune args

* lint

* Refactor gemm example and autotuner logging

- Updated `ref_program` in `example_gemm.py` to return the result of matrix multiplication instead of modifying an input parameter.
- Changed logging filename in `__init__.py` from 'out.log' to 'autotuner.log' for better clarity.
- Modified JIT kernel compilation process to include `out_idx` directly in the adapter creation, enhancing flexibility.
- Improved validation of `result_idx` in `BaseKernelAdapter` to ensure it falls within valid bounds.

* Refactor `ref_program` in `benchmark_matmul_intrinsic.py` to use the `@` operator for matrix multiplication instead of `torch.matmul`, simplifying the implementation by removing the unused parameter `C`.

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 60923344
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
-def ref_program(A, B):
+def ref_program(A, B, C):
     """
     A reference matrix multiplication program, used to compare performance.
@@ -174,7 +174,6 @@ def matmul(M, N, K, with_roller):
         supply_type=tl.TensorSupplyType.Integer,
         ref_prog=ref_program,
         skip_check=True,
-        profiler="auto",
         target="auto",
     )
     def kernel(
...
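The first hunk switches this benchmark's reference program from returning a result to accumulating into an explicit output parameter `C`, while the gemm example further down moves in the opposite direction. For a caller that zero-initializes the buffer, the two conventions agree; a small standalone PyTorch sketch (plain torch code, not tilelang API):

```python
import torch

def ref_returning(A, B):
    # Convention 1: take only the inputs and return a fresh result tensor.
    return A @ B.T

def ref_in_place(A, B, C):
    # Convention 2: also receive the output buffer and accumulate into it.
    C += A @ B.T

A = torch.randn(64, 32)
B = torch.randn(48, 32)

C = torch.zeros(64, 48)
ref_in_place(A, B, C)

# With a zero-initialized C, both conventions produce the same values.
assert torch.allclose(C, ref_returning(A, B))
```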
 import argparse
 import logging
-import torch
-import torch.backends
 from tilelang import tvm as tvm
 from tvm import DataType
 import tilelang as tl
...
@@ -161,7 +159,7 @@ def tl_matmul(
 def ref_program(A, B):
     """Reference matrix multiplication program."""
-    return torch.matmul(A, B.T)
+    return A @ B.T
 def get_configs(M, N, K, with_roller=False):
...
@@ -271,7 +269,6 @@ def matmul(M,
         supply_type=tl.TensorSupplyType.Integer,
         ref_prog=ref_program,
         skip_check=True,
-        profiler="auto",
         target="auto",
     )
     def kernel(
...
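The `torch.matmul(A, B.T)` to `A @ B.T` change in this file is purely cosmetic; the `@` operator dispatches to the same matrix multiply. A quick standalone check:

```python
import torch

A = torch.randn(128, 64, dtype=torch.float16)
B = torch.randn(256, 64, dtype=torch.float16)

# Both expressions invoke the same underlying matmul, so the results are identical.
assert torch.equal(torch.matmul(A, B.T), A @ B.T)
```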
@@ -96,11 +96,7 @@ def convolution(N, C, H, W, F, K, S, D, P, tune=False):
         keys=["block_M", "block_N", "block_K", "num_stages", "threads"],
         warmup=10,
         rep=10)
-    @jit(
-        out_idx=[2],
-        supply_type=tilelang.TensorSupplyType.Integer,
-        ref_prog=None,
-        profiler="auto")
+    @jit(out_idx=[2], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
     def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
         return kernel_func(block_M, block_N, block_K, num_stages, threads)
...
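The same one-line collapse of `@jit(...)`, dropping `profiler="auto"`, repeats across the remaining example hunks below. Because these examples stack `@autotune(...)` on top of `@jit(...)`, changing the inner decorator's arguments leaves the outer one untouched; decorators apply bottom-up. A plain-Python schematic of that stacking with hypothetical stand-ins (`autotune`, `jit`, and `kernel` here are placeholders, not the tilelang implementations):

```python
from functools import wraps

def jit(out_idx=None, ref_prog=None):
    # Inner decorator: pretend to compile the program into a kernel.
    def wrapper(fn):
        @wraps(fn)
        def compiled(*args, **kwargs):
            return f"compiled({fn(*args, **kwargs)}, out_idx={out_idx})"
        return compiled
    return wrapper

def autotune(keys, warmup=10, rep=10):
    # Outer decorator: pretend to sweep configurations over the jit-compiled callable.
    def wrapper(fn):
        @wraps(fn)
        def tuned(*args, **kwargs):
            return f"tuned[{','.join(keys)}] -> {fn(*args, **kwargs)}"
        return tuned
    return wrapper

@autotune(keys=["block_M", "block_N"], warmup=10, rep=10)
@jit(out_idx=[2])
def kernel(block_M=None, block_N=None):
    return f"kernel(block_M={block_M}, block_N={block_N})"

# jit wraps kernel first, then autotune wraps the result.
print(kernel(block_M=128, block_N=128))
```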
@@ -240,11 +240,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
         keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
         warmup=10,
         rep=10)
-    @jit(
-        out_idx=[2],
-        supply_type=tilelang.TensorSupplyType.Integer,
-        ref_prog=None,
-        profiler="auto")
+    @jit(out_idx=[2], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
     def kernel(block_M=None,
                block_N=None,
                block_K=None,
...
@@ -189,11 +189,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
         keys=["block_M", "block_N", "num_stages", "threads"],
         warmup=10,
         rep=10)
-    @jit(
-        out_idx=[3],
-        supply_type=tilelang.TensorSupplyType.Integer,
-        ref_prog=None,
-        profiler="auto")
+    @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
     def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
         return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -161,11 +161,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
         keys=["block_M", "block_N", "num_stages", "threads"],
         warmup=10,
         rep=10)
-    @jit(
-        out_idx=[3],
-        supply_type=tilelang.TensorSupplyType.Integer,
-        ref_prog=None,
-        profiler="auto")
+    @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
     def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
         return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -157,11 +157,7 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
         keys=["block_M", "block_N", "num_stages", "threads"],
         warmup=10,
         rep=10)
-    @jit(
-        out_idx=[3],
-        supply_type=tilelang.TensorSupplyType.Integer,
-        ref_prog=None,
-        profiler="auto")
+    @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
     def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
         return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -154,11 +154,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
         keys=["block_M", "block_N", "num_stages", "threads"],
         warmup=10,
         rep=10)
-    @jit(
-        out_idx=[3],
-        supply_type=tilelang.TensorSupplyType.Integer,
-        ref_prog=None,
-        profiler="auto")
+    @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
     def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
         return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -159,11 +159,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
         keys=["block_M", "block_N", "num_stages", "threads"],
         warmup=10,
         rep=10)
-    @jit(
-        out_idx=[3],
-        supply_type=tilelang.TensorSupplyType.Integer,
-        ref_prog=None,
-        profiler="auto")
+    @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
     def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
         return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -277,8 +277,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
         out_idx=[6],
         supply_type=tilelang.TensorSupplyType.Auto,
         ref_prog=ref_program,
-        max_mismatched_ratio=0.05,
-        profiler="auto")
+        max_mismatched_ratio=0.05)
     def kernel(block_N=None, block_H=None, num_split=None, num_stages=None, threads=None):
         return kernel_func(block_N, block_H, num_split, num_stages, threads)
...
@@ -9,8 +9,8 @@ from tilelang.carver.arch import CUDA
 from tilelang.carver.roller.rasterization import NoRasterization
-def ref_program(A, B, C):
-    C += A @ B.T
+def ref_program(A, B):
+    return A @ B.T
 def get_configs(M, N, K, with_roller=False):
...
@@ -200,11 +200,7 @@ def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
         keys=["block_M", "block_N", "block_K", "block_Dstate", "num_stages", "threads"],
         warmup=10,
         rep=10)
-    @jit(
-        out_idx=[7],
-        supply_type=tilelang.TensorSupplyType.Normal,
-        ref_prog=None,
-        profiler="auto")
+    @jit(out_idx=[7], supply_type=tilelang.TensorSupplyType.Normal, ref_prog=None)
     def kernel(block_M=None,
                block_N=None,
                block_K=None,
...
@@ -142,11 +142,7 @@ def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
         keys=["block_M", "block_N", "block_K", "num_stages", "threads"],
         warmup=10,
         rep=10)
-    @jit(
-        out_idx=[4],
-        supply_type=tilelang.TensorSupplyType.Normal,
-        ref_prog=None,
-        profiler="auto")
+    @jit(out_idx=[4], supply_type=tilelang.TensorSupplyType.Normal, ref_prog=None)
     def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
         return kernel_func(block_M, block_N, block_K, num_stages, threads)
...
@@ -15,7 +15,7 @@ from functools import partial
 logger = logging.getLogger(__name__)
 logging.basicConfig(
-    filename='out.log',
+    filename='autotuner.log',
     filemode='w',
     level=logging.INFO,
     format='%(asctime)s %(levelname)s:%(message)s')
@@ -204,9 +204,9 @@ def jit(out_idx: List[int],
         @wraps(fn)
         def decorator(*args, **kwargs) -> float:
-            # Enabling Efficient Fusion
-            kernel = tilelang.compile(
-                fn(*args, **kwargs), target=target, pass_configs={"tir.merge_static_smem": True})
+            kernel = tilelang.compile(fn(*args, **kwargs), out_idx=out_idx, target=target)
             profiler = kernel.get_profiler()
             return JITContext(
...
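The autotuner change above threads `out_idx` straight into compilation instead of relying on a separate `profiler` argument. The mechanism is an ordinary closure: the decorator factory captures `out_idx` (and `target`) and reuses them every time a candidate configuration is compiled. A minimal standalone sketch of that pattern, with a hypothetical `compile_program` stand-in rather than the real `tilelang.compile`:

```python
from functools import wraps

def compile_program(program, out_idx=None, target="auto"):
    # Hypothetical stand-in for the real compiler call.
    return {"program": program, "out_idx": out_idx, "target": target}

def jit(out_idx=None, target="auto"):
    def wrapper(fn):
        @wraps(fn)
        def decorator(*args, **kwargs):
            # out_idx captured from the enclosing jit(...) call is forwarded to
            # compilation, so the compiled kernel knows which parameters are outputs.
            return compile_program(fn(*args, **kwargs), out_idx=out_idx, target=target)
        return decorator
    return wrapper

@jit(out_idx=[2], target="auto")
def matmul_program(M, N, K):
    return f"matmul({M}x{K} @ {K}x{N})"

print(matmul_program(1024, 1024, 512))
# {'program': 'matmul(1024x512 @ 512x1024)', 'out_idx': [2], 'target': 'auto'}
```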
@@ -27,7 +27,15 @@ class BaseKernelAdapter(ABC):
             if result_idx < 0:
                 result_idx = len(params) + result_idx
             result_idx = [result_idx]
-        elif not isinstance(result_idx, list):
+        elif isinstance(result_idx, list):
+            for i, idx in enumerate(result_idx):
+                if idx > len(params) or idx < -len(params):
+                    raise ValueError(
+                        f"result_idx should be an integer between {-len(params)} and {len(params) - 1}"
+                    )
+                if idx < 0:
+                    result_idx[i] = len(params) + idx
+        else:
             raise ValueError("result_idx should be a list of integers")
         return result_idx
...
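The new `BaseKernelAdapter` branch bounds-checks a list-valued `result_idx` and wraps negative indices, mirroring the existing single-integer path. A standalone sketch of that normalization with a usage example (plain Python, not an import of tilelang; the bounds check here treats `num_params - 1` as the largest valid index):

```python
from typing import List, Union

def normalize_result_idx(result_idx: Union[int, List[int]], num_params: int) -> List[int]:
    """Normalize result_idx to a list of non-negative parameter indices."""
    if isinstance(result_idx, int):
        if result_idx < 0:
            result_idx = num_params + result_idx
        return [result_idx]
    if isinstance(result_idx, list):
        normalized = []
        for idx in result_idx:
            if idx >= num_params or idx < -num_params:
                raise ValueError(
                    f"result_idx should be an integer between {-num_params} and {num_params - 1}")
            normalized.append(num_params + idx if idx < 0 else idx)
        return normalized
    raise ValueError("result_idx should be a list of integers")

# For a kernel with parameters (A, B, C): -1 wraps around to the last parameter, C.
print(normalize_result_idx(-1, 3))       # [2]
print(normalize_result_idx([0, -1], 3))  # [0, 2]
```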
@@ -69,7 +69,6 @@ class JITKernel(object):
         from_database : bool, optional
             Whether to create a TorchFunction from a database.
         """
-        self.out_idx = out_idx
         self.execution_backend = execution_backend
         self.target = target
         self.target_host = target_host
...
@@ -104,7 +103,7 @@ class JITKernel(object):
             return
         # Compile the TileLang function and create a kernel adapter for execution.
-        adapter = self._compile_and_create_adapter(func)
+        adapter = self._compile_and_create_adapter(func, out_idx)
         # The adapter's function is assigned as the callable function for this instance.
         self.adapter = adapter
...
@@ -165,7 +164,8 @@ class JITKernel(object):
         """
         return self.torch_function(*args, **kwds)
-    def _compile_and_create_adapter(self, tilelang_func: PrimFunc) -> BaseKernelAdapter:
+    def _compile_and_create_adapter(self, tilelang_func: PrimFunc,
+                                    out_idx: List[int]) -> BaseKernelAdapter:
         """
         Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter.
...
@@ -182,7 +182,7 @@ class JITKernel(object):
         verbose = self.verbose
         target = self.target
         target_host = self.target_host
-        out_idx = self.out_idx
         execution_backend = self.execution_backend
         pass_configs = self.pass_configs
...
@@ -335,6 +335,10 @@ class JITKernel(object):
     def run_once(self, func: Optional[Callable] = None) -> None:
         return self.get_profiler().run_once(func)
+    @property
+    def out_idx(self) -> List[int]:
+        return self.adapter.result_idx
+
     @property
     def params(self) -> List[KernelParam]:
         return self.artifact.params if self.artifact else self.adapter.params
...
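After this change `JITKernel` no longer caches `out_idx` on itself: the value is handed to the adapter at construction time and read back through the new `out_idx` property, so `adapter.result_idx` is the single source of truth. A minimal standalone sketch of that ownership pattern (plain Python classes standing in for the real `JITKernel` and adapter):

```python
from typing import List

class Adapter:
    """Stand-in for a kernel adapter that owns the normalized output indices."""

    def __init__(self, result_idx: List[int]):
        self.result_idx = result_idx

class Kernel:
    """Stand-in for JITKernel: out_idx is forwarded to the adapter, not stored twice."""

    def __init__(self, out_idx: List[int]):
        self.adapter = self._create_adapter(out_idx)

    def _create_adapter(self, out_idx: List[int]) -> Adapter:
        return Adapter(result_idx=out_idx)

    @property
    def out_idx(self) -> List[int]:
        # Delegate to the adapter so the two views never disagree.
        return self.adapter.result_idx

kernel = Kernel(out_idx=[2])
print(kernel.out_idx)  # [2]
```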