"docs/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "632df3ea2f99f3c8e4d2a16fab6ebe4303609da1"
Commit 7171aff6 authored by Lei Wang, committed by LeiWang1999

[Autotune] Introduce cache mechanism for auto tuner (#527)

* [Enhancement] Add commit ID to versioning and improve logging initialization

* Updated `get_tilelang_version` to include an optional commit ID in the version string.
* Enhanced the `TileLangBuilPydCommand` to write the version with commit ID to the VERSION file during the build process.
* Introduced a new function `get_git_commit_id` in `version.py` to retrieve the current git commit hash (see the sketch after this list).
* Refactored logger initialization in `autotuner/__init__.py` to ensure handlers are set up only once, improving performance and clarity.
* Minor fixes in `flatten_buffer.cc` and `kernel_cache.py` for better handling of versioning and logging.
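
The helper itself is not part of this message; below is a minimal sketch of what `get_git_commit_id` and a commit-aware version string could look like, assuming `version.py` simply shells out to `git rev-parse`. The names mirror the bullets above, but the actual implementation in the repository may differ.

```python
# Hypothetical sketch only; not the actual contents of version.py.
import subprocess


def get_git_commit_id() -> str:
    """Return the current git commit hash, or "" when not inside a git repo."""
    try:
        out = subprocess.check_output(
            ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL
        )
        return out.decode().strip()
    except (OSError, subprocess.CalledProcessError):
        return ""


def get_tilelang_version(with_commit_id: bool = False) -> str:
    base = "0.1.0"  # placeholder; the real value is read from the VERSION file
    commit = get_git_commit_id() if with_commit_id else ""
    return f"{base}+{commit[:8]}" if commit else base
```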

* [Refactor] Enhance AutoTuner and JITKernel for improved performance and caching

* Refactored the AutoTuner class to include new methods for setting compilation and profiling arguments, enhancing configurability.
* Introduced caching mechanisms for tuning results, allowing for faster retrieval of previously computed configurations.
* Updated JITKernel to store tuning results, including latency and configuration details, improving the kernel's performance tracking.
* Added new methods for generating cache keys and saving/loading results to/from disk, streamlining the tuning process (a sketch of the general idea follows this list).
* Enhanced the overall structure and readability of the autotuning logic, ensuring better maintainability and clarity.
* Minor adjustments in related modules to support the new caching and profiling features.
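
The exact cache layout is not spelled out here; the sketch below only illustrates the general approach of hashing everything that affects the tuning outcome into a key and persisting the best latency/config pair as JSON. All names and paths in it are illustrative, not the actual AutoTuner internals.

```python
# Illustrative only: not the real AutoTuner cache format or location.
import hashlib
import inspect
import json
from pathlib import Path

CACHE_DIR = Path.home() / ".tilelang_autotune_cache"  # hypothetical directory


def make_cache_key(kernel_fn, configs, version: str) -> str:
    """Hash the kernel source, candidate configs, and version into a stable key."""
    payload = json.dumps(
        {"source": inspect.getsource(kernel_fn), "configs": configs, "version": version},
        sort_keys=True,
        default=str,  # tolerate non-JSON-serializable config values
    )
    return hashlib.sha256(payload.encode()).hexdigest()


def save_result(key: str, latency: float, config: dict) -> None:
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    (CACHE_DIR / f"{key}.json").write_text(
        json.dumps({"latency": latency, "config": config}))


def load_result(key: str):
    path = CACHE_DIR / f"{key}.json"
    return json.loads(path.read_text()) if path.exists() else None
```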

* [Refactor] Clean up code formatting and improve readability in AutoTuner and related modules

* Consolidated import statements and removed unnecessary line breaks for better readability.
* Standardized function argument formatting across the AutoTuner and CompileArgs classes.
* Enhanced consistency in the use of whitespace and indentation throughout the codebase.
* Minor adjustments in the Profiler and JITKernel classes to improve clarity and maintainability.
* Ensured that all changes adhere to the project's coding style guidelines.

* [Refactor] Remove redundant type hints in AutoTuner modules

* Simplified import statements in `__init__.py` and `param.py` by removing unnecessary duplicate type hints for `Any`.
* Improved code readability and maintainability by streamlining type imports across the AutoTuner module.

* [Refactor] Update AutoTuner configuration for improved profiling and target detection

* Enhanced the AutoTuner configuration across multiple examples by adding `set_profile_args` to better manage profiling settings (the usage sketch after this list shows the resulting pattern).
* Standardized the use of `target="auto"` in compile arguments to ensure automatic target detection.
* Removed redundant target specifications in certain instances to streamline the configuration process.
* Improved overall clarity and maintainability of the autotuning logic in various example scripts.
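
Taken together, the examples now configure the tuner through the chained API shown below; the values are illustrative, and `kernel`, `get_configs`, and `ref_program` come from the individual example scripts rather than this snippet.

```python
# Sketch of the updated configuration flow used throughout the examples.
from tilelang.autotuner import AutoTuner

autotuner = AutoTuner.from_kernel(
    kernel=kernel, configs=get_configs(M, N, K)
).set_compile_args(
    out_idx=[-1],          # index of the output tensor
    target="auto",         # automatic target detection
).set_profile_args(
    ref_prog=ref_program,  # reference program used for checking/benchmarking
)
result = autotuner.run(warmup=3, rep=20)
print(result.latency, result.config)
```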

* [Refactor] Simplify code formatting and improve readability in example scripts

* Consolidated function argument formatting in `benchmark_mla_decode_amd_tilelang.py`, `example_elementwise_add.py`, and `performance.py` for better clarity.
* Removed unnecessary line breaks and standardized argument placement across multiple files.
* Enhanced overall code readability and maintainability in autotuning examples and performance scripts.

* [Refactor] Update JIT decorator usage across multiple files

* Removed redundant parameters from the JIT decorator in various benchmark and example scripts, simplifying the code.
* Standardized the import of the JIT decorator from `tilelang`, enhancing consistency across the codebase (see the before/after sketch following this list).
* Improved overall readability and maintainability by consolidating import statements and cleaning up function definitions.
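
For the decorated examples, the change boils down to the pattern below; `kernel_func` and `get_configs` are whatever the individual example defines, and the removed keyword arguments now belong to the profiling configuration instead.

```python
# Old form (removed): profiling options were passed to the autotuner's jit.
# from tilelang.autotuner import autotune, jit
# @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)

# New form: jit comes from tilelang and keeps only compile-time options.
from tilelang.autotuner import autotune
from tilelang import jit


@autotune(configs=get_configs(), warmup=10, rep=10)
@jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
    return kernel_func(block_M, block_N, num_stages, threads)
```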

* [Refactor] Standardize JIT decorator formatting across benchmark and example scripts

* Simplified the formatting of the JIT decorator in multiple files by removing unnecessary line breaks.
* Enhanced code readability and consistency in the usage of the JIT decorator across benchmark and example scripts.
* Improved overall maintainability by ensuring uniformity in function definitions and decorator usage.
parent 09581e4e
@@ -2,10 +2,9 @@ import argparse
 import itertools
 import logging
-import tilelang as tl
 import tilelang.language as T
-from tilelang.autotuner import autotune, jit
+from tilelang.autotuner import autotune
+from tilelang import jit
 # Configure logger
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -160,13 +159,7 @@ def matmul(M, N, K, with_roller):
 warmup=3,
 rep=20,
 )
-@jit(
-out_idx=[2],
-supply_type=tl.TensorSupplyType.Integer,
-ref_prog=ref_program,
-skip_check=True,
-target="auto",
-)
+@jit(out_idx=[2],)
 def kernel(
 block_M=None,
 block_N=None,
...
@@ -8,7 +8,7 @@ from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mma_macro_generator import (
 TensorCoreIntrinEmitter,)
 from tilelang.transform import simplify_prim_func
-from tilelang.autotuner import autotune, jit
+from tilelang.autotuner import autotune
 import itertools
 # Configure logger
@@ -264,13 +264,7 @@ def matmul(M,
 warmup=3,
 rep=5,
 )
-@jit(
-out_idx=[2],
-supply_type=tl.TensorSupplyType.Integer,
-ref_prog=ref_program,
-skip_check=True,
-target="auto",
-)
+@tl.jit(out_idx=[2],)
 def kernel(
 block_row_warps=None,
 block_col_warps=None,
...
@@ -2,9 +2,9 @@ import argparse
 import itertools
 import logging
-import tilelang as tl
 import tilelang.language as T
-from tilelang.autotuner import autotune, jit
+from tilelang.autotuner import autotune
+from tilelang import jit
 # Configure logger
 logger = logging.getLogger(__name__)
@@ -160,13 +160,7 @@ def matmul(M, N, K, with_roller):
 warmup=3,
 rep=20,
 )
-@jit(
-out_idx=[2],
-supply_type=tl.TensorSupplyType.Integer,
-ref_prog=ref_program,
-skip_check=True,
-target="auto",
-)
+@jit(out_idx=[2],)
 def kernel(
 block_M=None,
 block_N=None,
...
@@ -107,7 +107,8 @@ def get_best_config(M, N, K):
 kernel=kernel, configs=get_configs(M, N, K)
 ).set_compile_args(
 out_idx=[-1],  # Index of the output tensor
+target="auto",  # Automatically detect target
+).set_profile_args(
 # supply_type should not set here because we provide a custom supply
 # function `supply_prog` and `supply_type` will be ignored.
@@ -133,7 +134,6 @@ def get_best_config(M, N, K):
 # different configurations. Reusing cached tensors from a previous
 # configuration would lead to shape mismatches.
 cache_input_tensors=False,
-target="auto",  # Automatically detect target
 )
 # Run the tuning process
 return autotuner.run(warmup=3, rep=20)
...
@@ -325,11 +325,7 @@ if __name__ == "__main__":
 num_split, thread_num)
 if enable_autotune:
-autotuner = AutoTuner.from_kernel(
-kernel=wrapped_kernel, configs=get_configs()).set_compile_args(
-supply_type=tilelang.TensorSupplyType.Integer,
-target="auto",
-)
+autotuner = AutoTuner.from_kernel(kernel=wrapped_kernel, configs=get_configs())
 tune_result = autotuner.run(warmup=3, rep=20)
 best_latency = tune_result.latency
 best_config = tune_result.config
...
@@ -240,7 +240,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
 keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
 warmup=10,
 rep=10)
-@jit(out_idx=[2], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
+@tilelang.jit(out_idx=[2])
 def kernel(block_M=None,
 block_N=None,
 block_K=None,
...
@@ -13,8 +13,8 @@ def ref_program(x, y):
 def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
 @T.prim_func
-def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N),
-out_dtype)):
+def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor(
+(M, N), out_dtype)):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 start_x = bx * block_N
 start_y = by * block_M
@@ -42,13 +42,15 @@ def get_best_config(M, N):
 autotuner = AutoTuner.from_kernel(
 kernel=kernel, configs=get_configs(M, N)).set_compile_args(
 out_idx=[-1],
+target="cuda",
+).set_profile_args(
 supply_type=tilelang.TensorSupplyType.Auto,
 ref_prog=ref_program,
 skip_check=False,
-target="cuda",
 )
 return autotuner.run(warmup=3, rep=20)
 def main():
 parser = argparse.ArgumentParser()
 parser.add_argument("--m", type=int, default=512)
...
@@ -189,7 +189,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
 keys=["block_M", "block_N", "num_stages", "threads"],
 warmup=10,
 rep=10)
-@jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
+@tilelang.jit(out_idx=[3])
 def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
 return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -161,7 +161,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
 keys=["block_M", "block_N", "num_stages", "threads"],
 warmup=10,
 rep=10)
-@jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
+@tilelang.jit(out_idx=[3])
 def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
 return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -6,6 +6,7 @@ import tilelang.language as T
 import itertools
 import argparse
 from functools import partial
+from tilelang import jit
 def get_configs():
@@ -153,7 +154,7 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
 if tune:
 @autotune(configs=get_configs(), warmup=10, rep=10)
-@jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
+@jit(out_idx=[3])
 def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
 return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -6,6 +6,7 @@ import tilelang.language as T
 import itertools
 import argparse
 from functools import partial
+from tilelang import jit
 def get_configs():
@@ -158,7 +159,7 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
 if tune:
 @autotune(configs=get_configs(), warmup=10, rep=10)
-@jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
+@jit(out_idx=[3])
 def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
 return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -150,7 +150,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
 if tune:
 @autotune(configs=get_configs(), warmup=10, rep=10)
-@jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
+@tilelang.jit(out_idx=[3])
 def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
 return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -155,7 +155,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
 if tune:
 @autotune(configs=get_configs(), warmup=10, rep=10)
-@jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
+@tilelang.jit(out_idx=[3])
 def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
 return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -118,10 +118,11 @@ def get_best_config(M, N, K, with_roller=False):
 autotuner = AutoTuner.from_kernel(
 kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
 out_idx=[-1],
+target="auto",
+).set_profile_args(
 supply_type=tl.TensorSupplyType.Integer,
 ref_prog=ref_program,
 skip_check=False,
-target="auto",
 )
 return autotuner.run(warmup=3, rep=20)
...
@@ -3,7 +3,8 @@ import itertools
 import tilelang as tl
 import tilelang.language as T
 from tvm import DataType
-from tilelang.autotuner import autotune, jit
+from tilelang.autotuner import autotune
+from tilelang import jit
 def ref_program(A, B):
@@ -232,9 +233,6 @@ def get_best_config(N, K):
 )
 @jit(
 out_idx=[-1],
-supply_type=tl.TensorSupplyType.Integer,
-ref_prog=ref_program,
-skip_check=False,
 target="auto",
 )
 def kernel(
@@ -317,7 +315,7 @@ def main():
 best_result = get_best_config(N, K)
 best_config = best_result.config
-kernel = splitk_gemv_vectorized_tvm(N, K, *best_config)
+kernel = splitk_gemv_vectorized_tvm(N, K, **best_config)
 kernel = tl.compile(kernel, out_idx=-1)
 profiler = kernel.get_profiler()
 latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500)
...
@@ -196,7 +196,7 @@ def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
 if tune:
 @autotune(configs=get_configs(), warmup=10, rep=10)
-@jit(out_idx=[7], supply_type=tilelang.TensorSupplyType.Normal, ref_prog=None)
+@tilelang.jit(out_idx=[7])
 def kernel(block_M=None,
 block_N=None,
 block_K=None,
...
@@ -138,7 +138,7 @@ def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
 if tune:
 @autotune(configs=get_configs(), warmup=10, rep=10)
-@jit(out_idx=[4], supply_type=tilelang.TensorSupplyType.Normal, ref_prog=None)
+@tilelang.jit(out_idx=[4])
 def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
 return kernel_func(block_M, block_N, block_K, num_stages, threads)
...
 import argparse
-import tilelang as tl
 import tilelang.language as T
 from tilelang.autotuner import AutoTuner
@@ -64,11 +63,9 @@ def run(M, N, K):
 autotuner = AutoTuner.from_kernel(
 kernel=kernel, configs=get_configs()).set_compile_args(
 out_idx=[-1],
-supply_type=tl.TensorSupplyType.Integer,
-ref_prog=ref_program,
-skip_check=False,
 target="auto",
-)
+).set_profile_args(
+ref_prog=ref_program,)
 return autotuner.run(warmup=3, rep=20)
...
 import itertools
 import logging
-import tilelang as tl
 import tilelang.testing
 import tilelang.language as T
 from tilelang.autotuner import AutoTuner
@@ -251,11 +250,9 @@ def matmul(M, N, K, with_roller):
 autotuner = AutoTuner.from_kernel(
 kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
 out_idx=[-1],
-supply_type=tl.TensorSupplyType.Integer,
-ref_prog=ref_program,
-skip_check=False,
 target="auto",
-)
+).set_profile_args(
+ref_prog=ref_program,)
 return autotuner.run(warmup=3, rep=20)
...
 import itertools
 import logging
-import tilelang as tl
 import tilelang.testing
 import tilelang.language as T
-from tilelang.autotuner import jit, autotune
+from tilelang.autotuner import autotune
 # Configure logger
 logger = logging.getLogger(__name__)
@@ -148,13 +147,7 @@ def matmul(M, N, K, with_roller):
 warmup=3,
 rep=20,
 )
-@jit(
-out_idx=[-1],
-supply_type=tl.TensorSupplyType.Integer,
-ref_prog=ref_program,
-skip_check=True,
-target="auto",
-)
+@tilelang.jit(out_idx=[-1],)
 def kernel(
 block_M=None,
 block_N=None,
...