Commit 66c7f6a1 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Updated autotune usage in the examples to align with the latest changes (#309)

* [Enhancement] Add support for CUDA architecture 8.9 in GEMM template

- Introduced conditional inclusion of "gemm_sm89.h" for CUDA architectures 8.9 and above, enhancing compatibility with newer hardware.
- This change ensures that the GEMM template can leverage optimizations specific to the 8.9 architecture, improving performance for users with compatible GPUs.

* lintfix

* [Refactor] Clean up includes in gemm_sm89.h

- Removed duplicate inclusion of "common.h" and added "cuda_fp8.h" for improved clarity and organization.
- This change enhances the maintainability of the code by ensuring that header files are included only once and in a logical order.

* [Enhancement] Improve KernelCache with in-memory caching and detailed docstrings

- Added an in-memory cache to the KernelCache class to improve performance by reducing disk access (see the sketch after this list).
- Updated the __new__ method to initialize the memory cache and added logic to check the cache before loading from disk.
- Enhanced docstrings across multiple methods to provide clearer explanations of parameters and return values, improving code readability and maintainability.
- Implemented a clear_cache method to clear both in-memory and disk caches, ensuring efficient cache management.
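
A minimal sketch of the two-level caching pattern described above, assuming a singleton KernelCache; the helper names (_load_from_disk, _clear_disk_cache) are illustrative stand-ins, not the actual tilelang API:

import threading

class KernelCache:
    """Singleton cache: check memory first, fall back to disk."""
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._memory_cache = {}  # initialized once, in __new__
            return cls._instance

    def get(self, key):
        # Check the in-memory cache before touching the disk.
        if key in self._memory_cache:
            return self._memory_cache[key]
        kernel = self._load_from_disk(key)
        if kernel is not None:
            self._memory_cache[key] = kernel  # promote disk hits to memory
        return kernel

    def clear_cache(self):
        # Clear both the in-memory and the disk cache.
        self._memory_cache.clear()
        self._clear_disk_cache()

    def _load_from_disk(self, key):
        return None  # placeholder for the real disk lookup

    def _clear_disk_cache(self):
        pass  # placeholder for the real disk-cache removal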

* lint fix

* typofix

* [Refactor] Update matmul and flashattn function calls to return structured results

- Modified matmul and flashattn to return a single result object containing latency, configuration, and reference latency, improving code clarity and reducing the number of returned values.
- Updated all relevant call sites in the benchmark and example scripts to the new return structure, ensuring consistent usage across the codebase (a before/after sketch follows this list).
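
A self-contained before/after sketch of the new calling convention; the AutotuneResult dataclass and the dummy matmul below are illustrative stand-ins, with field names taken from the diff:

from dataclasses import dataclass
from typing import Any

@dataclass
class AutotuneResult:
    """Illustrative stand-in; field names match the accesses in the diff."""
    latency: float
    config: Any
    ref_latency: float

def matmul(M, N, K, with_roller):
    # Dummy tuned entry point, for illustration only.
    return AutotuneResult(latency=1.2e-3, config={"block_M": 128}, ref_latency=1.5e-3)

# Before: best_latency, best_config, ref_latency = matmul(M, N, K, with_roller)
# After:
best_result = matmul(1024, 1024, 1024, with_roller=True)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency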

* lint fix
parent 5802c01b
@@ -157,15 +157,6 @@ def matmul(M, N, K, with_roller):
     @autotune(
         configs=get_configs(M, N, K, with_roller),
-        keys=[
-            "block_M",
-            "block_N",
-            "block_K",
-            "num_stages",
-            "thread_num",
-            "policy",
-            "enable_rasteration",
-        ],
         warmup=3,
         rep=20,
     )
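
Written out without diff markers, the updated decorator call looks as follows; the import path and the stub kernel are assumptions, and get_configs, M, N, K, with_roller come from the example file:

# Sketch only: the keys list is no longer passed -- the tunable parameters
# now come solely from the entries of `configs`.
from tilelang.autotuner import autotune  # assumed import location

@autotune(
    configs=get_configs(M, N, K, with_roller),
    warmup=3,   # warmup iterations per candidate config
    rep=20,     # timed repetitions per candidate config
)
def kernel(block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasteration):
    ...  # kernel construction as in the example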
@@ -295,7 +286,10 @@ if __name__ == "__main__":
     total_flops = 2 * M * N * K

-    # matmul(...) returns (best_latency, best_config, ref_latency)
-    best_latency, best_config, ref_latency = matmul(M, N, K, with_roller)
+    best_result = matmul(M, N, K, with_roller)
+    best_latency = best_result.latency
+    best_config = best_result.config
+    ref_latency = best_result.ref_latency
     # Print out the benchmark results
     print(f"Best latency (s): {best_latency}")
@@ -323,8 +323,10 @@ if __name__ == "__main__":
     total_flops = 2 * M * N * K
     # Run autotuning
-    best_latency, best_config, ref_latency = matmul(M, N, K, in_dtype, out_dtype, accum_dtype,
-                                                    with_roller)
+    best_result = matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_roller)
+    best_latency = best_result.latency
+    best_config = best_result.config
+    ref_latency = best_result.ref_latency
     # Print benchmark results
     print(f"Best latency (s): {best_latency}")
@@ -291,8 +291,10 @@ if __name__ == "__main__":
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
-        best_latency, best_config, ref_latency = matmul(
-            M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)
+        best_result = matmul(M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)
+        best_latency = best_result.latency
+        best_config = best_result.config
+        ref_latency = best_result.ref_latency
         print(f"Best latency: {best_latency}")
         print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
         print(f"Best config: {best_config}")
@@ -259,8 +259,10 @@ if __name__ == "__main__":
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
-        best_latency, best_config, _ = flashattn(
-            batch, heads, seq_len, dim, is_causal, tune=args.tune)
+        best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=args.tune)
+        best_latency = best_result.latency
+        best_config = best_result.config
+        ref_latency = best_result.ref_latency
         print(f"Best latency: {best_latency}")
         print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
         print(f"Best config: {best_config}")
@@ -231,8 +231,10 @@ if __name__ == "__main__":
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
-        best_latency, best_config, _ = flashattn(
-            batch, heads, seq_len, dim, is_causal, tune=args.tune)
+        best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=args.tune)
+        best_latency = best_result.latency
+        best_config = best_result.config
+        ref_latency = best_result.ref_latency
         print(f"Best latency: {best_latency}")
         print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
         print(f"Best config: {best_config}")
@@ -218,8 +218,10 @@ if __name__ == "__main__":
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
-        best_latency, best_config, _ = flashattn(
-            batch, heads, seq_q, seq_kv, dim, is_causal, tune=args.tune)
+        best_result = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=args.tune)
+        best_latency = best_result.latency
+        best_config = best_result.config
+        ref_latency = best_result.ref_latency
         print(f"Best latency: {best_latency}")
         print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
         print(f"Best config: {best_config}")
@@ -212,8 +212,10 @@ if __name__ == "__main__":
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
-        best_latency, best_config, _ = flashattn(
-            batch, heads, seq_len, dim, is_causal, tune=args.tune)
+        best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=args.tune)
+        best_latency = best_result.latency
+        best_config = best_result.config
+        ref_latency = best_result.ref_latency
         print(f"Best latency: {best_latency}")
         print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
         print(f"Best config: {best_config}")
@@ -217,8 +217,10 @@ if __name__ == "__main__":
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
-        best_latency, best_config, _ = flashattn(
-            batch, heads, seq_len, dim, is_causal, tune=args.tune)
+        best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=args.tune)
+        best_latency = best_result.latency
+        best_config = best_result.config
+        ref_latency = best_result.ref_latency
         print(f"Best latency: {best_latency}")
         print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
         print(f"Best config: {best_config}")
@@ -441,8 +441,10 @@ if __name__ == "__main__":
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
-        best_latency, best_config, _ = flashattn(
-            batch, heads, groups, kv_seqlen, dim, tune=args.tune)
+        best_result = flashattn(batch, heads, groups, kv_seqlen, dim, tune=args.tune)
+        best_latency = best_result.latency
+        best_config = best_result.config
+        ref_latency = best_result.ref_latency
         print(f"Best latency: {best_latency}")
         print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
         print(f"Best config: {best_config}")
@@ -21,9 +21,9 @@ def naive_gemv(
     @T.prim_func
     def main(
-            A: T.Buffer((K,), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((N,), dtype),
+            A: T.Tensor((K,), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((N,), dtype),
     ):
         with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn:
             tn = T.get_thread_binding(0)  # tn = threadIdx.x
@@ -54,9 +54,9 @@ def naive_splitk_gemv(
     @T.prim_func
     def main(
-            A: T.Buffer((K,), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((N,), dtype),
+            A: T.Tensor((K,), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((N,), dtype),
     ):
         with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn:
             tn = T.get_thread_binding(0)
@@ -91,9 +91,9 @@ def splitk_gemv(
     @T.prim_func
     def main(
-            A: T.Buffer((K,), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((N,), dtype),
+            A: T.Tensor((K,), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((N,), dtype),
     ):
         with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
             tn = T.get_thread_binding(0)
@@ -131,9 +131,9 @@ def splitk_gemv_vectorized(
     @T.prim_func
     def main(
-            A: T.Buffer((K,), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((N,), dtype),
+            A: T.Tensor((K,), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((N,), dtype),
     ):
         with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
             tn = T.get_thread_binding(0)
@@ -171,9 +171,9 @@ def splitk_gemv_vectorized_tvm(
     @T.prim_func
     def main(
-            A: T.Buffer((K,), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((N,), dtype),
+            A: T.Tensor((K,), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((N,), dtype),
     ):
         with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
             tn = T.get_thread_binding(0)
@@ -227,10 +227,6 @@ def get_best_config(N, K):
     @autotune(
         configs=get_configs(),
-        keys=[
-            "BLOCK_N",
-            "reduce_threads",
-        ],
         warmup=3,
         rep=20,
     )
@@ -253,9 +249,9 @@ def get_best_config(N, K):
     @T.prim_func
     def main(
-            A: T.Buffer((K,), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((N,), dtype),
+            A: T.Tensor((K,), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((N,), dtype),
     ):
         with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
             tn = T.get_thread_binding(0)
@@ -319,7 +315,10 @@ if __name__ == "__main__":
     check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K)
     print("Test passed!")

-    best_latency, best_config, ref_latency = get_best_config(N, K)
+    best_result = get_best_config(N, K)
+    best_latency = best_result.latency
+    best_config = best_result.config
+    ref_latency = best_result.ref_latency
     kernel = splitk_gemv_vectorized_tvm(N, K, *best_config)
     kernel = tl.compile(kernel, out_idx=-1)
     profiler = kernel.get_profiler()
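
Note that best_config is splatted positionally into the kernel builder above; a minimal illustration, assuming the tuner returns config as an ordered tuple of (BLOCK_N, reduce_threads):

def splitk_gemv_vectorized_tvm(N, K, BLOCK_N, reduce_threads):
    # Stand-in for the real builder: just echoes its arguments.
    return (N, K, BLOCK_N, reduce_threads)

best_config = (2, 32)  # hypothetical tuned values
kernel = splitk_gemv_vectorized_tvm(1024, 1024, *best_config)
assert kernel == (1024, 1024, 2, 32)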
@@ -247,8 +247,11 @@ if __name__ == "__main__":
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
-        best_latency, best_config, _ = chunk_scan_fwd(
+        best_result = chunk_scan_fwd(
             batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)
+        best_latency = best_result.latency
+        best_config = best_result.config
+        ref_latency = best_result.ref_latency
         print(f"Best latency: {best_latency}")
         print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
         print(f"Best config: {best_config}")
@@ -184,8 +184,11 @@ if __name__ == "__main__":
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
-        best_latency, best_config, _ = chunk_state_fwd(
+        best_result = chunk_state_fwd(
             batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)
+        best_latency = best_result.latency
+        best_config = best_result.config
+        ref_latency = best_result.ref_latency
         print(f"Best latency: {best_latency}")
         print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
         print(f"Best config: {best_config}")
@@ -8,7 +8,7 @@ import tilelang
 from tilelang import tvm as tvm
 import inspect
 from functools import wraps, partial
-from typing import Callable, List, Literal, Any
+from typing import Callable, List, Literal, Any, Optional
 from tqdm import tqdm
 import logging
 from dataclasses import dataclass
@@ -340,7 +340,7 @@ class AutoTuner:
         return self.run()


-def autotune(configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100) -> Callable:
+def autotune(configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100) -> AutotuneResult:
     """Decorator for auto-tuning tilelang programs.

     Args:
@@ -362,7 +362,7 @@ def autotune(configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100)
     return decorator


-def jit(out_idx: List[int],
+def jit(out_idx: Optional[List[int]] = None,
         supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
         ref_prog: Callable = None,
         supply_prog: Callable = None,
@@ -109,7 +109,7 @@ def jit(
 def compile(
     func: PrimFunc = None,
-    out_idx: Union[List[int], int] = None,
+    out_idx: Union[List[int], int, None] = None,
     execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
     target: Union[str, Target] = "auto",
     target_host: Union[str, Target] = None,
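
With out_idx now optional in both jit and compile, callers may omit it. A short sketch, assuming `tl` aliases tilelang as in the GEMV example and `func` is a PrimFunc built there:

kernel = tl.compile(func, out_idx=-1)  # explicit output index, as before
kernel = tl.compile(func)              # now valid: out_idx defaults to None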
@@ -23,15 +23,16 @@ class BaseKernelAdapter(ABC):
         elif isinstance(result_idx, int):
             if result_idx > len(params) or result_idx < -len(params):
                 raise ValueError(
-                    f"result_idx should be an integer between {-len(params)} and {len(params) - 1}")
+                    f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}"
+                )
             if result_idx < 0:
                 result_idx = len(params) + result_idx
             result_idx = [result_idx]
         elif isinstance(result_idx, list):
             for i, idx in enumerate(result_idx):
-                if idx > len(params) or idx < -len(params):
+                if idx >= len(params) or idx <= -len(params):
                     raise ValueError(
-                        f"result_idx should be an integer between {-len(params)} and {len(params) - 1}"
+                        f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}"
                     )
                 if idx < 0:
                     result_idx[i] = len(params) + idx
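
As a standalone illustration of the negative-index normalization above (pure Python, no tilelang required):

params = ["A", "B", "C"]   # e.g. three kernel parameters
result_idx = -1            # request the last parameter as the output
if result_idx < 0:
    result_idx = len(params) + result_idx  # -1 -> 2
result_idx = [result_idx]
assert result_idx == [2]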