"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "24f74e72f11535bc73b0af9bea33be534fb21181"
Commit 0430cfe7 authored by Chaofan Lin, committed by LeiWang1999

[Bugfix] Fix Benchmark/Example Code for Autotuning (#254)



* fix tune args

* lint

* Refactor gemm example and autotuner logging

- Updated `ref_program` in `example_gemm.py` to return the result of matrix multiplication instead of modifying an input parameter.
- Changed logging filename in `__init__.py` from 'out.log' to 'autotuner.log' for better clarity.
- Modified the JIT kernel compilation process to pass `out_idx` directly into adapter creation instead of storing it on the kernel; `JITKernel.out_idx` is now a property backed by the adapter's `result_idx`.
- Improved validation of `result_idx` in `BaseKernelAdapter` so that each index is checked against the number of kernel parameters and negative indices are normalized (sketched after this list).
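
A minimal standalone sketch of the bounds check and negative-index normalization described above; `normalize_result_idx` and the bare parameter count are hypothetical stand-ins for the adapter's real `params` list, not the library's actual method:

```python
from typing import List, Union


def normalize_result_idx(result_idx: Union[int, List[int], None],
                         num_params: int) -> List[int]:
    """Return result_idx as a list of non-negative, in-bounds indices."""
    if result_idx is None:
        return []
    if isinstance(result_idx, int):
        result_idx = [result_idx]
    if not isinstance(result_idx, list):
        raise ValueError("result_idx should be a list of integers")
    normalized = []
    for idx in result_idx:
        if idx >= num_params or idx < -num_params:
            raise ValueError(
                f"result_idx should be an integer between {-num_params} and {num_params - 1}")
        # Map negative indices (e.g. -1 for the last buffer) to their positive form.
        normalized.append(num_params + idx if idx < 0 else idx)
    return normalized


# Example: a kernel with parameters (A, B, C); -1 refers to the last buffer.
assert normalize_result_idx(-1, 3) == [2]
assert normalize_result_idx([0, -1], 3) == [0, 2]
```

The committed adapter code performs the same normalization in place on the incoming list; the sketch returns a new list for clarity.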

* Refactor `ref_program` in `benchmark_matmul_intrinsic.py` to use the `@` operator for matrix multiplication instead of `torch.matmul`, and drop the unused parameter `C` (before/after sketched below).
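
For illustration, a hedged before/after of the `ref_program` cleanup these notes describe (the exact old bodies differ slightly between files; shapes and names below are arbitrary):

```python
import torch


# Old style: an extra output argument mutated in place (or torch.matmul called directly).
def ref_program_old(A, B, C):
    C += torch.matmul(A, B.T)


# New style: a pure function that returns its result via the @ operator.
def ref_program(A, B):
    return A @ B.T


A = torch.randn(4, 8)
B = torch.randn(6, 8)
torch.testing.assert_close(ref_program(A, B), torch.matmul(A, B.T))
```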

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 60923344
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
-def ref_program(A, B, C):
+def ref_program(A, B):
"""
A reference matrix multiplication program, used to compare performance.
@@ -174,7 +174,6 @@ def matmul(M, N, K, with_roller):
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=True,
-profiler="auto",
target="auto",
)
def kernel(
......
import argparse
import logging
import torch
import torch.backends
from tilelang import tvm as tvm
from tvm import DataType
import tilelang as tl
@@ -161,7 +159,7 @@ def tl_matmul(
def ref_program(A, B):
"""Reference matrix multiplication program."""
-return torch.matmul(A, B.T)
+return A @ B.T
def get_configs(M, N, K, with_roller=False):
@@ -271,7 +269,6 @@ def matmul(M,
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=True,
-profiler="auto",
target="auto",
)
def kernel(
......
@@ -96,11 +96,7 @@ def convolution(N, C, H, W, F, K, S, D, P, tune=False):
keys=["block_M", "block_N", "block_K", "num_stages", "threads"],
warmup=10,
rep=10)
-@jit(
-out_idx=[2],
-supply_type=tilelang.TensorSupplyType.Integer,
-ref_prog=None,
-profiler="auto")
+@jit(out_idx=[2], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
......
@@ -240,11 +240,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
warmup=10,
rep=10)
-@jit(
-out_idx=[2],
-supply_type=tilelang.TensorSupplyType.Integer,
-ref_prog=None,
-profiler="auto")
+@jit(out_idx=[2], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
def kernel(block_M=None,
block_N=None,
block_K=None,
......
@@ -189,11 +189,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
-@jit(
-out_idx=[3],
-supply_type=tilelang.TensorSupplyType.Integer,
-ref_prog=None,
-profiler="auto")
+@jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
......
@@ -161,11 +161,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
-@jit(
-out_idx=[3],
-supply_type=tilelang.TensorSupplyType.Integer,
-ref_prog=None,
-profiler="auto")
+@jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
......
@@ -157,11 +157,7 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
-@jit(
-out_idx=[3],
-supply_type=tilelang.TensorSupplyType.Integer,
-ref_prog=None,
-profiler="auto")
+@jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
......
@@ -154,11 +154,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
-@jit(
-out_idx=[3],
-supply_type=tilelang.TensorSupplyType.Integer,
-ref_prog=None,
-profiler="auto")
+@jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
......
@@ -159,11 +159,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
-@jit(
-out_idx=[3],
-supply_type=tilelang.TensorSupplyType.Integer,
-ref_prog=None,
-profiler="auto")
+@jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
......
@@ -277,8 +277,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
out_idx=[6],
supply_type=tilelang.TensorSupplyType.Auto,
ref_prog=ref_program,
-max_mismatched_ratio=0.05,
-profiler="auto")
+max_mismatched_ratio=0.05)
def kernel(block_N=None, block_H=None, num_split=None, num_stages=None, threads=None):
return kernel_func(block_N, block_H, num_split, num_stages, threads)
......
@@ -9,8 +9,8 @@ from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization
-def ref_program(A, B, C):
-C += A @ B.T
+def ref_program(A, B):
+return A @ B.T
def get_configs(M, N, K, with_roller=False):
......
@@ -200,11 +200,7 @@ def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
keys=["block_M", "block_N", "block_K", "block_Dstate", "num_stages", "threads"],
warmup=10,
rep=10)
-@jit(
-out_idx=[7],
-supply_type=tilelang.TensorSupplyType.Normal,
-ref_prog=None,
-profiler="auto")
+@jit(out_idx=[7], supply_type=tilelang.TensorSupplyType.Normal, ref_prog=None)
def kernel(block_M=None,
block_N=None,
block_K=None,
......
@@ -142,11 +142,7 @@ def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
keys=["block_M", "block_N", "block_K", "num_stages", "threads"],
warmup=10,
rep=10)
-@jit(
-out_idx=[4],
-supply_type=tilelang.TensorSupplyType.Normal,
-ref_prog=None,
-profiler="auto")
+@jit(out_idx=[4], supply_type=tilelang.TensorSupplyType.Normal, ref_prog=None)
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
......
@@ -15,7 +15,7 @@ from functools import partial
logger = logging.getLogger(__name__)
logging.basicConfig(
-filename='out.log',
+filename='autotuner.log',
filemode='w',
level=logging.INFO,
format='%(asctime)s %(levelname)s:%(message)s')
@@ -204,9 +204,9 @@ def jit(out_idx: List[int],
@wraps(fn)
def decorator(*args, **kwargs) -> float:
# Enabling Efficient Fusion
-kernel = tilelang.compile(
-fn(*args, **kwargs), target=target, pass_configs={"tir.merge_static_smem": True})
+kernel = tilelang.compile(fn(*args, **kwargs), out_idx=out_idx, target=target)
profiler = kernel.get_profiler()
return JITContext(
......
@@ -27,7 +27,15 @@ class BaseKernelAdapter(ABC):
if result_idx < 0:
result_idx = len(params) + result_idx
result_idx = [result_idx]
-elif not isinstance(result_idx, list):
+elif isinstance(result_idx, list):
+for i, idx in enumerate(result_idx):
+if idx > len(params) or idx < -len(params):
+raise ValueError(
+f"result_idx should be an integer between {-len(params)} and {len(params) - 1}"
+)
+if idx < 0:
+result_idx[i] = len(params) + idx
+else:
raise ValueError("result_idx should be a list of integers")
return result_idx
......
@@ -69,7 +69,6 @@ class JITKernel(object):
from_database : bool, optional
Whether to create a TorchFunction from a database.
"""
-self.out_idx = out_idx
self.execution_backend = execution_backend
self.target = target
self.target_host = target_host
@@ -104,7 +103,7 @@ class JITKernel(object):
return
# Compile the TileLang function and create a kernel adapter for execution.
-adapter = self._compile_and_create_adapter(func)
+adapter = self._compile_and_create_adapter(func, out_idx)
# The adapter's function is assigned as the callable function for this instance.
self.adapter = adapter
@@ -165,7 +164,8 @@ class JITKernel(object):
"""
return self.torch_function(*args, **kwds)
-def _compile_and_create_adapter(self, tilelang_func: PrimFunc) -> BaseKernelAdapter:
+def _compile_and_create_adapter(self, tilelang_func: PrimFunc,
+out_idx: List[int]) -> BaseKernelAdapter:
"""
Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter.
@@ -182,7 +182,7 @@ class JITKernel(object):
verbose = self.verbose
target = self.target
target_host = self.target_host
-out_idx = self.out_idx
execution_backend = self.execution_backend
pass_configs = self.pass_configs
@@ -335,6 +335,10 @@ class JITKernel(object):
def run_once(self, func: Optional[Callable] = None) -> None:
return self.get_profiler().run_once(func)
+@property
+def out_idx(self) -> List[int]:
+return self.adapter.result_idx
@property
def params(self) -> List[KernelParam]:
return self.artifact.params if self.artifact else self.adapter.params
......