Commit 92e8d5f4 authored by Haodong Tian, committed by LeiWang1999

[Bugfix] Resolve autotuner bugs for blocksparse GEMM example (#300)

* [Bugfix] Configure an autotuner-specific logger for correct level handling
- Previously, logging relied on basicConfig, which configured the root logger. This caused the named autotuner logger to ignore DEBUG messages.
- This commit sets up a dedicated logger for the autotuner, correctly routing DEBUG messages to 'autotuner.log' and INFO+ messages to the console.

* [Bugfix] Fix tensor_supply for boolean type
- Previously `get_tensor_supply` used `torch.randint(-2, 3)` as a fallback, which caused an error when the dtype was `torch.bool`.
- This commit adds an `is_boolean` check in `KernelParam` and updates `get_tensor_supply` to use `torch.randint(0, 2)` for boolean dtypes.
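
  For illustration only (not part of the diff below), a minimal sketch of the failure mode and the fix; the exact error message is not quoted here:

      import torch

      shape = (4,)
      # Old fallback: values sampled from [-2, 3) cannot be represented in a bool
      # tensor, so this raises an error when dtype is torch.bool:
      # torch.randint(low=-2, high=3, size=shape, dtype=torch.bool)

      # New behavior for boolean parameters: sample 0/1, which maps to False/True.
      mask = torch.randint(low=0, high=2, size=shape, dtype=torch.bool)
      print(mask)  # e.g. tensor([ True, False,  True,  True])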

* [Bugfix] Always regenerate JIT inputs during tuning
- Removes the caching of `self.jit_input_tensors` within `AutoTuner`. Because different autotuning configurations can alter the required input tensor shapes or other properties, reusing cached inputs from a previous configuration can lead to errors or incorrect assessments.
- This change ensures that `profiler._get_inputs()` is called unconditionally for each configuration evaluation. Since `_get_inputs` is assumed to be relatively inexpensive, the potential overhead is considered acceptable.

* [Example] Update example_blocksparse_gemm for autotuner

* Run code formatter

* [Feature] Enable custom tensor supply and input caching control in Autotuner
- Previously, tensor generation was tied to `supply_type`, and input caching behavior across configurations was neither explicit nor controllable.
- This commit introduces a `supply_prog` parameter that allows providing a custom function for generating input tensors, overriding the default mechanism (a condensed usage sketch follows this list).
- Adds a `cache_input_tensors` flag (default True) to control input tensor caching:
    - If True, tensors are generated once per configuration and reused for repetitions, with a check for potential shape mismatches between configurations.
    - If False, tensors are regenerated for every configuration trial.
- Refactors internal input tensor handling using supplier functions for clarity.
- Adds a `check_tensor_list_compatibility` utility for shape comparison.
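
  A condensed usage sketch of the new knobs, mirroring the updated example further below. `my_supply`, `make_tensor_for`, `kernel`, and `configs` are illustrative placeholders, not names from this commit:

      def my_supply(params):  # receives the kernel's list of input KernelParam objects
          return [make_tensor_for(p) for p in params]

      autotuner = AutoTuner.from_kernel(kernel=kernel, configs=configs).set_compile_args(
          out_idx=[-1],
          supply_prog=my_supply,        # custom input generation; overrides supply_type
          cache_input_tensors=False,    # regenerate inputs for every configuration trial
      )
      result = autotuner.run(warmup=3, rep=20)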

* [Example] Update example_blocksparse_gemm for autotuner

* Run code formatter

* [Example] Small fix in example_blocksparse_gemm

* [Fix] Raise error if autotuning yields no valid configuration
parent 1873dc00
@@ -2,8 +2,36 @@ import argparse
 import itertools
 import tilelang
 import tilelang.language as T
+from tilelang.autotuner import AutoTuner
+from tilelang.engine.param import KernelParam
+from tilelang.utils.tensor import get_tensor_supply
 import torch
-from tilelang.autotuner import autotune, jit
+from typing import List
+
+DEFAULT_BLOCK_M = 128
+DEFAULT_BLOCK_N = 128
+DEFAULT_BLOCK_K = 32
+DEFAULT_NUM_STAGES = 2
+DEFAULT_THREAD_NUM = 128
+DEFAULT_ENABLE_RASTERIZATION = True
+
+parser = argparse.ArgumentParser(description="Autotuned BlockSparse MatMul Benchmark")
+parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
+parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
+parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
+parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)")
+parser.add_argument(
+    "--use_autotune", action="store_true", default=False, help="Whether to use autotune")
+args = parser.parse_args()
+
+M, N, K = args.m, args.n, args.k
+sparsity = args.sparsity
+use_autotune = args.use_autotune
+
+default_tensor_supply = get_tensor_supply()
+
+print(f"Running BlockSparse MatMul Benchmark for M={M}, N={N}, K={K}")
+print(f"Target Block Sparsity: {sparsity}")
+print(f"Using Autotuner: {use_autotune}\n")
+
 def get_configs(M, N, K):
@@ -27,30 +55,45 @@ def get_configs(M, N, K):
     } for c in _configs]
 
-def ref_program(A, B, BlockMask, C):
-    batch_M = A.shape[0] // block_M
-    batch_N = B.shape[1] // block_N
-    batch_K = A.shape[1] // block_K
-    for i in range(batch_M):
-        for j in range(batch_N):
+def ref_program(A, B, BlockMask, block_M, block_N, block_K):
+    ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device)
+    for i in range(M // block_M):
+        for j in range(N // block_N):
             accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
-            for k in range(batch_K):
+            for k in range(K // block_K):
                 if BlockMask[i, j, k]:
-                    accu += A[i*block_M:(i+1)*block_M, k*block_K:(k+1)*block_K].to(torch.float32) @ \
-                            B[k*block_K:(k+1)*block_K, j*block_N:(j+1)*block_N].to(torch.float32)
-            C[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
+                    accu += (
+                        A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
+                            torch.float32) @ B[k * block_K:(k + 1) * block_K,
+                                               j * block_N:(j + 1) * block_N].to(torch.float32))
+            ref_c[i * block_M:(i + 1) * block_M,
+                  j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
+    return ref_c
+
+
+def supply_program(params: List[KernelParam]):
+    input_tensors = []
+    for p in params:
+        # Check if the kernel parameter is BlockMask tensor.
+        # Here, BlockMask is uniquely identified by having 3 dimensions.
+        if len(p.shape) != 3:
+            # For non-BlockMask tensors, use the default tensor generation logic.
+            input_tensors.append(default_tensor_supply(p))
+        else:
+            # For BlockMask tensor, randomly set elements to True based on desired
+            # sparsity level.
+            block_mask = torch.zeros(p.shape, dtype=torch.bool, device=torch.cuda.current_device())
+            block_mask[:, :, :] = torch.rand(p.shape) > sparsity
+            input_tensors.append(block_mask)
+    return input_tensors
+
 
 def get_best_config(M, N, K):
-    @autotune(
-        configs=get_configs(M, N, K),
-        keys=["block_M", "block_N", "block_K", "num_stages", "thread_num", "enable_rasteration"],
-        warmup=3,
-        rep=20,
-    )
-    @jit(out_idx=[-1], ref_prog=ref_program)
+    # Define the kernel function to be tuned.
+    # Parameters like block_M, block_N, etc., are tuned by the AutoTuner.
     def kernel(block_M=None,
                block_N=None,
                block_K=None,
@@ -60,7 +103,40 @@ def get_best_config(M, N, K):
         return blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num,
                                   enable_rasteration)
 
-    return kernel()
+    autotuner = AutoTuner.from_kernel(
+        kernel=kernel, configs=get_configs(M, N, K)
+    ).set_compile_args(
+        out_idx=[-1],  # Index of the output tensor
+
+        # supply_type should not be set here because we provide a custom supply
+        # function `supply_prog`, so `supply_type` will be ignored.
+
+        # supply_prog: Provide the custom function to generate input tensors
+        # (A, B, BlockMask) for the kernel, allowing controlling sparsity via
+        # BlockMask generation.
+        supply_prog=supply_program,
+
+        # ref_prog: Using dense matmul (A @ B) as a placeholder reference.
+        # The 'correct' block-sparse reference (`ref_program` above) requires
+        # block_M, block_N, block_K parameters. However, these parameters are
+        # part of the configuration being *tuned* by the AutoTuner and cannot
+        # be fixed inputs to a static `ref_prog` function signature.
+        # This dense matmul serves only as a performance baseline.
+        ref_prog=lambda A, B, BlockMask: A @ B,
+
+        # skip_check: Set to True because the provided `ref_prog` does not
+        # compute the correct result for the block-sparse kernel.
+        skip_check=True,
+
+        # cache_input_tensors: Set to False because the shape of the BlockMask tensor
+        # (dependent on block_M, block_N, block_K being tuned) changes between
+        # different configurations. Reusing cached tensors from a previous
+        # configuration would lead to shape mismatches.
+        cache_input_tensors=False,
+        target="auto",  # Automatically detect target
+    )
+
+    # Run the tuning process
+    return autotuner.run(warmup=3, rep=20)
 
 
 def blocksparse_matmul(M,
@@ -106,47 +182,52 @@ def blocksparse_matmul(M,
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Autotuned BlockSparse MatMul Benchmark")
-    parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
-    parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
-    parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
-    parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)")
-    parser.add_argument(
-        "--use_autotune", action="store_true", default=False, help="Whether to use autotune")
-    args = parser.parse_args()
-    M, N, K = args.m, args.n, args.k
-
-    # Initialize input matrices
+    # Initialize input matrices A and B on the GPU with half precision
     a = torch.randn(M, K).cuda().half()
     b = torch.randn(K, N).cuda().half()
 
     if args.use_autotune:
-        best_latency, best_config, ref_latency = get_best_config(M, N, K)
-        func = blocksparse_matmul(M, N, K, *best_config)
+        # Run the autotuner to find the best kernel configuration and performance.
+        # get_best_config is expected to return an object containing the compiled kernel,
+        # the best configuration found, latency, and reference latency.
+        result = get_best_config(M, N, K)
+
+        # Extract results from the autotuner run
+        kernel = result.kernel
+        best_config = result.config
+        block_M = best_config[0]
+        block_N = best_config[1]
+        block_K = best_config[2]
+        best_latency = result.latency
+        ref_latency = result.ref_latency
+
+        print(f"Best Config: {best_config}")
+        print(f"Block Dimensions (BM, BN, BK): ({block_M}, {block_N}, {block_K})")
+        print(f"Sparsity Ratio: {sparsity}")
+        print(f"Best Kernel Latency: {best_latency:.6f} ms")
+        print(f"Reference Latency: {ref_latency:.6f} ms")
     else:
-        func = blocksparse_matmul(M, N, K, 128, 128, 32, 2, 128, True)
+        func = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K,
+                                  DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM,
+                                  DEFAULT_ENABLE_RASTERIZATION)
+        kernel = tilelang.compile(func, out_idx=-1)
+        block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
+        print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
 
     # Create block mask with desired sparsity
-    block_M, block_N, block_K = 128, 128, 32  # default values if not using autotune
     mask_shape = (M // block_M, N // block_N, K // block_K)
-    block_mask = torch.rand(mask_shape).cuda() > args.sparsity
+    block_mask = torch.rand(mask_shape).cuda() > sparsity
 
-    kernel = tilelang.compile(func, out_idx=-1)
+    # Run the compiled kernel (either tuned or default) with the inputs
     c = kernel(a, b, block_mask)
 
-    # Verify result
-    ref_c = torch.zeros_like(c)
-    for i in range(M // block_M):
-        for j in range(N // block_N):
-            accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=a.device)
-            for k in range(K // block_K):
-                if block_mask[i, j, k]:
-                    accu += (
-                        a[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
-                            torch.float32) @ b[k * block_K:(k + 1) * block_K,
-                                               j * block_N:(j + 1) * block_N].to(torch.float32))
-            ref_c[i * block_M:(i + 1) * block_M,
-                  j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
-    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+    # Compute the reference result using the naive PyTorch implementation
+    ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K)
+
+    try:
+        torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+        print("✅ Results are close! Verification successful.")
+    except AssertionError as e:
+        print("❌ Verification FAILED: Results differ significantly.")
+        print(e)
@@ -13,15 +13,28 @@ from tqdm import tqdm
 import logging
 from dataclasses import dataclass
 import concurrent.futures
+import torch
 import os
+import sys
 
+# Configure logging for the autotuner module
+# TODO: Consider creating a common logger in utils
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+logger.propagate = False
 
-logging.basicConfig(
-    filename='autotuner.log',
-    filemode='w',
-    level=logging.DEBUG,
-    format='%(asctime)s %(levelname)s:%(message)s')
+formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
+
+file_handler = logging.FileHandler('autotuner.log', mode='w')
+file_handler.setLevel(logging.DEBUG)
+file_handler.setFormatter(formatter)
+
+console_handler = logging.StreamHandler(sys.stdout)
+console_handler.setLevel(logging.INFO)
+console_handler.setFormatter(formatter)
+
+logger.addHandler(file_handler)
+logger.addHandler(console_handler)
 
 @dataclass(frozen=True)
@@ -30,22 +43,24 @@ class JITContext:
     Attributes:
         out_idx: List of output tensor indices.
-        supply_type: Type of tensor supply mechanism.
         ref_prog: Reference program for correctness validation.
+        supply_prog: Supply program for input tensors.
         rtol: Relative tolerance for output validation.
         atol: Absolute tolerance for output validation.
         max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
         skip_check: Whether to skip validation checks.
+        cache_input_tensors: Whether to cache input tensors for each compilation.
         profiler: Profiler instance for performance measurement.
         target: Target platform ('cuda' or 'hip').
     """
     out_idx: List[int]
-    supply_type: tilelang.TensorSupplyType
     ref_prog: Callable
+    supply_prog: Callable
     rtol: float
     atol: float
     max_mismatched_ratio: float
     skip_check: bool
+    cache_input_tensors: bool
     profiler: tilelang.Profiler
     target: Literal['cuda', 'hip']
@@ -103,40 +118,51 @@ class AutoTuner:
     def set_compile_args(self,
                          out_idx: List[int],
-                         supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Normal,
+                         supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
                          ref_prog: Callable = None,
+                         supply_prog: Callable = None,
                          rtol: float = 1e-2,
                          atol: float = 1e-2,
                          max_mismatched_ratio: float = 0.01,
                          skip_check: bool = False,
+                         cache_input_tensors: bool = True,
                          target: Literal['auto', 'cuda', 'hip'] = 'auto'):
         """Set compilation arguments for the auto-tuner.
 
         Args:
             out_idx: List of output tensor indices.
-            supply_type: Type of tensor supply mechanism.
+            supply_type: Type of tensor supply mechanism. Ignored if `supply_prog` is provided.
             ref_prog: Reference program for validation.
+            supply_prog: Supply program for input tensors.
             rtol: Relative tolerance for validation.
             atol: Absolute tolerance for validation.
             max_mismatched_ratio: Maximum allowed mismatch ratio.
             skip_check: Whether to skip validation.
+            cache_input_tensors: Whether to cache input tensors.
             target: Target platform.
 
         Returns:
            AutoTuner: Self for method chaining.
         """
+        # If a custom `supply_prog` is provided, the profiler's `supply_type` setting
+        # becomes ineffective. The custom supply program will be used instead.
+        if ref_prog is not None and supply_type != tilelang.TensorSupplyType.Auto:
+            logger.warning("Ignoring `supply_type` passed to `set_compile_args` because "
+                           "`ref_prog` is not None.")
 
         def _compile(*config_arg):
             kernel = tilelang.compile(self.fn(*config_arg), out_idx=out_idx, target=target)
-            profiler = kernel.get_profiler()
+            profiler = kernel.get_profiler(tensor_supply_type=supply_type)
             jit_context = JITContext(
                 out_idx=out_idx,
-                supply_type=supply_type,
                 ref_prog=ref_prog,
+                supply_prog=supply_prog,
                 rtol=rtol,
                 atol=atol,
                 max_mismatched_ratio=max_mismatched_ratio,
                 skip_check=skip_check,
+                cache_input_tensors=cache_input_tensors,
                 profiler=profiler,
                 target=target)
             return jit_context
@@ -163,28 +189,64 @@ class AutoTuner:
         best_config = None
         best_jit_context = None
 
-        def target_fn(jit_context):
+        def target_fn(jit_context: JITContext):
             # Unpack the context
             profiler = jit_context.profiler
             skip_check = jit_context.skip_check
+            cache_input_tensors = jit_context.cache_input_tensors
             ref_prog = jit_context.ref_prog
+            supply_prog = jit_context.supply_prog
             rtol = jit_context.rtol
             atol = jit_context.atol
             max_mismatched_ratio = jit_context.max_mismatched_ratio
 
-            self.jit_input_tensors = profiler._get_inputs(
-                with_output=profiler ==
-                "tvm") if self.jit_input_tensors is None else self.jit_input_tensors
+            # Factory functions for generating input tensors.
+            # This encapsulates the logic of using either a custom supply program (`supply_prog`)
+            # or the default profiler input generation (`profiler._get_inputs`).
+            def get_input_tensors_supply(with_output: bool):
+
+                def func():
+                    if supply_prog is not None:
+                        return supply_prog(profiler._get_params(with_output=with_output))
+                    else:
+                        return profiler._get_inputs(with_output=with_output)
+
+                return func
+
+            jit_input_tensors_supply = get_input_tensors_supply(with_output=(profiler == "tvm"))
+            ref_input_tensors_supply = get_input_tensors_supply(with_output=False)
+
+            if cache_input_tensors:
+                jit_input_tensors = jit_input_tensors_supply()
+                if self.jit_input_tensors is not None:
+                    if not check_tensor_list_compatibility(self.jit_input_tensors,
+                                                           jit_input_tensors):
+                        logger.warning(
+                            "Incompatible input tensor properties detected between cached tensors and "
+                            "tensors regenerated for the current configuration trial. "
+                            "This can happen if different tuning configurations require different input shapes/dtypes "
+                            "and input tensor caching is enabled.\n"
+                            "To ensure fresh, compatible inputs are generated for every trial "
+                            "you can disable caching by setting:\n"
+                            "    `cache_input_tensors=False`\n"
+                            "within your `.set_compile_args(...)` call.\n")
+                self.jit_input_tensors = jit_input_tensors
+            else:
+                self.jit_input_tensors = jit_input_tensors_supply()
 
             if (not skip_check) and (ref_prog is not None):
                 profiler.assert_allclose(
-                    ref_prog, rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio)
+                    ref_prog,
+                    input_tensors=self.jit_input_tensors,
+                    rtol=rtol,
+                    atol=atol,
+                    max_mismatched_ratio=max_mismatched_ratio)
 
             latency = profiler.do_bench(
                 profiler.func, n_warmup=warmup, n_repeat=rep, input_tensors=self.jit_input_tensors)
 
             if self.ref_latency_cache is None and ref_prog is not None:
-                self.ref_input_tensors = profiler._get_inputs(
-                    with_output=False) if self.ref_input_tensors is None else self.ref_input_tensors
+                self.ref_input_tensors = ref_input_tensors_supply()
                 self.ref_latency_cache = profiler.do_bench(
                     ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)
@@ -223,8 +285,9 @@ class AutoTuner:
                 try:
                     result = future.result()
                     results_with_configs.append((result, config))
-                except Exception:
-                    logger.debug(f"Compilation failed for config {config} at index {idx}")
+                except Exception as e:
+                    logger.debug(
+                        f"Compilation failed for config {config} at index {idx} with error: {e}")
                     continue
 
         ref_latency = None
@@ -253,6 +316,13 @@ class AutoTuner:
                 tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}")
 
         pool.shutdown()
+
+        if best_jit_context is None:
+            error_msg = ("Auto-tuning failed: No configuration successfully "
+                         "compiled and passed benchmarking/validation.")
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+
         return AutotuneResult(
             latency=best_latency,
             config=best_config,
@@ -293,49 +363,79 @@ def autotune(configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100)
 def jit(out_idx: List[int],
-        supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Normal,
+        supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
         ref_prog: Callable = None,
+        supply_prog: Callable = None,
         rtol: float = 1e-2,
         atol: float = 1e-2,
         max_mismatched_ratio: float = 0.01,
         skip_check: bool = False,
+        cache_input_tensors: bool = True,
         target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable:
     """Just-In-Time compilation decorator for tilelang programs.
 
     Args:
         out_idx: List of output tensor indices.
-        supply_type: Type of tensor supply mechanism.
+        supply_type: Type of tensor supply mechanism. Ignored if `supply_prog` is provided.
        ref_prog: Reference program for correctness validation.
+        supply_prog: Supply program for input tensors.
        rtol: Relative tolerance for output validation.
        atol: Absolute tolerance for output validation.
        max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
        skip_check: Whether to skip validation checks.
+        cache_input_tensors: Whether to cache input tensors for each compilation.
        target: Target platform ('auto', 'cuda', or 'hip').
 
     Returns:
         Callable: Decorated function that performs JIT compilation.
     """
+    # If a custom `supply_prog` is provided, the profiler's `supply_type` setting
+    # becomes ineffective. The custom supply program will be used instead.
+    if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto:
+        logger.warning("Ignoring `supply_type` passed to `autotune.jit` because "
+                       "`supply_prog` is not None.")
 
     def wrapper(fn: Callable):
 
         @wraps(fn)
         def decorator(*args, **kwargs) -> float:
             kernel = tilelang.compile(fn(*args, **kwargs), out_idx=out_idx, target=target)
-            profiler = kernel.get_profiler()
+            profiler = kernel.get_profiler(tensor_supply_type=supply_type)
             return JITContext(
                 out_idx=out_idx,
-                supply_type=supply_type,
                 ref_prog=ref_prog,
+                supply_prog=supply_prog,
                 rtol=rtol,
                 atol=atol,
                 max_mismatched_ratio=max_mismatched_ratio,
                 skip_check=skip_check,
+                cache_input_tensors=cache_input_tensors,
                 profiler=profiler,
                 target=target)
 
         return decorator
 
     return wrapper
+
+
+def check_tensor_list_compatibility(
+    list1: List[torch.Tensor],
+    list2: List[torch.Tensor],
+) -> bool:
+    """Checks if two lists of tensors are compatible.
+
+    Compatibility checks performed include:
+        1. Lists have the same length.
+        2. Corresponding tensors have the same shape.
+
+    Args:
+        list1: First list of tensors.
+        list2: Second list of tensors.
+    """
+    if len(list1) != len(list2):
+        return False
+    return all(tensor1.shape == tensor2.shape for tensor1, tensor2 in zip(list1, list2))
@@ -83,6 +83,15 @@ class KernelParam:
         """
         return str(self.dtype).removeprefix("torch.").startswith("float8")
 
+    def is_boolean(self) -> bool:
+        """
+        Checks if the parameter represents a boolean type.
+
+        Returns:
+            bool: True if parameter is a boolean type, False otherwise
+        """
+        return str(self.dtype).removeprefix("torch.").startswith("bool")
+
 
 @dataclass
 class CompiledArtifact:
 ...
@@ -66,9 +66,17 @@ class Profiler:
             ins.append(self.supply(self.params[i]))
         return ins
 
+    def _get_params(self, with_output=False):
+        params = []
+        for i in range(len(self.params)):
+            if with_output or i not in self.result_idx:
+                params.append(self.params[i])
+        return params
+
     def assert_allclose(
         self,
         reference_program: Callable,
+        input_tensors: Optional[List[torch.Tensor]] = None,
         atol: float = 1e-2,
         rtol: float = 1e-2,
         max_mismatched_ratio=0.01,
@@ -77,11 +85,12 @@
         Args:
             reference_program: Reference implementation to compare against
+            input_tensors: Optional pre-generated input tensors
             atol: Absolute tolerance for comparison
             rtol: Relative tolerance for comparison
            max_mismatched_ratio: Maximum allowed ratio of mismatched elements
         """
-        ins = self._get_inputs()
+        ins = self._get_inputs() if input_tensors is None else input_tensors
         ref_outs = reference_program(*ins)
         torch.cuda.synchronize()
         lib_outs = self.func(*ins)
 ...
@@ -71,11 +71,14 @@ def get_tensor_supply(supply_type: TensorSupplyType):
     if supply_type == TensorSupplyType.Auto:
         is_unsigned = param.is_unsigned()
         is_float8 = param.is_float8()
+        is_boolean = param.is_boolean()
         if is_unsigned:
             return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
         elif is_float8:
             return torch.randint(
                 low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype)
+        elif is_boolean:
+            return torch.randint(low=0, high=2, size=shape, device=device, dtype=dtype)
         elif dtype in {torch.float16, torch.float32, torch.bfloat16}:
             return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0)
         else:
@@ -90,11 +93,14 @@ def get_tensor_supply(supply_type: TensorSupplyType):
     if supply_type == TensorSupplyType.Integer:
         is_unsigned = param.is_unsigned()
         is_float8 = param.is_float8()
+        is_boolean = param.is_boolean()
        if is_unsigned:
             return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
         elif is_float8:
             return torch.randint(
                 low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype)
+        elif is_boolean:
+            return torch.randint(low=0, high=2, size=shape, device=device, dtype=dtype)
         else:
             return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
     elif supply_type == TensorSupplyType.Uniform:
 ...