Unverified Commit 3b21a67d authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[AMD][MLA] Fix mla autotune for rocm (#861)

* Refactor matmul example to include ReLU activation and update batch size in benchmark script

* lint fix

* Enhance autotuning capabilities in benchmark script and update argument defaults

- Introduced a new `get_configs` function to generate autotuning configurations for the benchmark.
- Updated the default batch size and kv context length in the argument parser for improved performance.
- Renamed the `--auto_tune` argument to `--autotune` for consistency.
- Modified the kernel invocation logic to support autotuning based on the new configurations.

* lint fix
parent b9a51c43
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import tilelang import tilelang
from tilelang.autotuner import *
import tilelang.language as T import tilelang.language as T
from einops import rearrange, einsum from einops import rearrange, einsum
import argparse import argparse
...@@ -9,6 +8,24 @@ import argparse ...@@ -9,6 +8,24 @@ import argparse
tilelang.disable_cache() tilelang.disable_cache()
def get_configs():
    """Build the autotuning search space for the MLA decode kernel.

    Returns:
        list[dict]: one config per combination of tile sizes, split count,
        and thread count — the keys match ``flashmla_decode``'s tunable
        keyword arguments (``block_N``, ``block_H``, ``num_split``, ``threads``).
    """
    import itertools

    block_n_choices = [16, 32, 64, 128]
    block_h_choices = [16, 32, 64, 128]
    split_choices = [1, 2, 4, 8, 16, 32]
    thread_choices = [128, 256]
    keys = ("block_N", "block_H", "num_split", "threads")
    # Cartesian product of all axes, each combination packed into a dict.
    return [
        dict(zip(keys, combo)) for combo in itertools.product(
            block_n_choices, block_h_choices, split_choices, thread_choices)
    ]
@tilelang.autotune(configs=get_configs())
@tilelang.jit( @tilelang.jit(
out_idx=[6], pass_configs={ out_idx=[6], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
...@@ -273,16 +290,16 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): ...@@ -273,16 +290,16 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size') parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number') parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=1024, help='kv context length') parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim') parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
parser.add_argument('--auto_tune', action='store_true', help='auto tune') parser.add_argument('--autotune', action='store_true', help='auto tune')
args = parser.parse_args() args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
enable_autotune = args.auto_tune enable_autotune = args.autotune
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim pv_flops = 2 * batch * heads * kv_ctx * dim
...@@ -290,9 +307,22 @@ if __name__ == "__main__": ...@@ -290,9 +307,22 @@ if __name__ == "__main__":
BLOCK_N = 32 BLOCK_N = 32
BLOCK_H = 64 BLOCK_H = 64
num_split = 4 num_split = 4
threads = 128
kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, if enable_autotune:
num_split) kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
else:
kernel = flashmla_decode(
batch,
heads,
kv_heads,
kv_ctx,
dim,
pe_dim,
BLOCK_N,
BLOCK_H,
num_split,
threads=threads)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
input_tensors = profiler._get_inputs() input_tensors = profiler._get_inputs()
tilelang_output = kernel(*input_tensors) tilelang_output = kernel(*input_tensors)
...@@ -303,35 +333,3 @@ if __name__ == "__main__": ...@@ -303,35 +333,3 @@ if __name__ == "__main__":
latency = profiler.do_bench(warmup=500) latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms") print(f"Latency: {latency} ms")
print(f"TFlops: {total_flops / latency * 1e-9} TFlops") print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
# Enable Auto Tuning
def get_configs():
    """Enumerate candidate kernel configurations for auto-tuning.

    Returns:
        list[dict]: every combination of tile sizes, split count, and
        thread count, keyed as ``block_N`` / ``block_H`` / ``num_split`` /
        ``thread_num`` to match ``wrapped_kernel``'s parameters.
    """
    import itertools

    tile_n = [16, 32, 64, 128]
    tile_h = [16, 32, 64, 128]
    splits = [1, 2, 4, 8, 16, 32]
    threads = [128, 256]
    keys = ("block_N", "block_H", "num_split", "thread_num")
    # One dict per point in the full Cartesian search space.
    return [
        dict(zip(keys, combo))
        for combo in itertools.product(tile_n, tile_h, splits, threads)
    ]
def wrapped_kernel(block_N=None, block_H=None, num_split=None, thread_num=None):
    """Bind the script's problem shape into ``flashmla_decode`` so the
    autotuner only has to vary the tunable parameters.

    Closes over the script-level ``batch``/``heads``/``kv_heads``/``kv_ctx``/
    ``dim``/``pe_dim`` parsed from the CLI arguments.
    """
    shape_args = (batch, heads, kv_heads, kv_ctx, dim, pe_dim)
    return flashmla_decode(*shape_args, block_N, block_H, num_split, thread_num)
# When --auto_tune is set, sweep the config space and report the best result.
if enable_autotune:
    tuner = AutoTuner.from_kernel(kernel=wrapped_kernel, configs=get_configs())
    tune_result = tuner.run(warmup=3, rep=20)
    # Best candidate found by the sweep (latency in ms).
    best_latency = tune_result.latency
    best_config = tune_result.config
    print(f"Best latency: {best_latency} ms")
    print(f"Best TFlops: {total_flops / best_latency * 1e-9} TFlops")
    print(f"Best config: {best_config}")
...@@ -104,6 +104,7 @@ class AutoTuner: ...@@ -104,6 +104,7 @@ class AutoTuner:
profile_args = ProfileArgs() profile_args = ProfileArgs()
_kernel_parameters: Optional[Tuple[str, ...]] = None _kernel_parameters: Optional[Tuple[str, ...]] = None
_function_parameters: Optional[Dict[str, Any]] = None
_lock = threading.Lock() # For thread safety _lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary _memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner" cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"
...@@ -222,9 +223,10 @@ class AutoTuner: ...@@ -222,9 +223,10 @@ class AutoTuner:
return self return self
def set_kernel_parameters(self, parameters: Tuple[str, ...]): def set_kernel_parameters(self, k_parameters: Tuple[str, ...], f_parameters: Dict[str, Any]):
# for cache key generation # for cache key generation
self._kernel_parameters = parameters self._kernel_parameters = k_parameters
self._function_parameters = f_parameters
def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]: def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
"""Generate a cache key for the auto-tuning process. """Generate a cache key for the auto-tuning process.
...@@ -417,8 +419,15 @@ class AutoTuner: ...@@ -417,8 +419,15 @@ class AutoTuner:
key_args_tuple, key_kwargs_tuple = self._kernel_parameters key_args_tuple, key_kwargs_tuple = self._kernel_parameters
tunable_arguments = [key for key, _ in top_config.items()] tunable_arguments = [key for key, _ in top_config.items()]
def check_tunable_argument_value(key, parameters, key_args_tuple) -> bool:
    """Return True when tunable argument *key* was supplied positionally.

    Args:
        key: name of a tunable parameter from the winning config.
        parameters: the kernel function's parameters (ordered mapping,
            e.g. from ``inspect.signature(fn).parameters``).
        key_args_tuple: the positional arguments the caller passed.

    A parameter is covered by the positional args exactly when its position
    in the signature is below ``len(key_args_tuple)``.
    """
    names = list(parameters)
    assert key in names, f"Tunable argument {key} not found in function parameters"
    return names.index(key) < len(key_args_tuple)
# Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple # Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple
if any(key in top_config for key, _ in key_kwargs_tuple): if any(key in top_config for key, _ in key_kwargs_tuple) or any(
check_tunable_argument_value(key, self._function_parameters, key_args_tuple)
for key in tunable_arguments):
logger.warning( logger.warning(
f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT" f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT"
) )
...@@ -676,7 +685,7 @@ class _AutoTunerImplementation: ...@@ -676,7 +685,7 @@ class _AutoTunerImplementation:
) )
autotuner.jit_compile = jit_compile autotuner.jit_compile = jit_compile
autotuner.set_kernel_parameters(key) autotuner.set_kernel_parameters(key, inspect.signature(fn).parameters)
autotuner.run = partial(autotuner.run, warmup, rep, timeout) autotuner.run = partial(autotuner.run, warmup, rep, timeout)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment