Commit 089cc0a7 authored by yuanjypku, committed by LeiWang1999

[Feature] Fix Device Consistency in Autotuner Threads and Add Manual Profiler Check (#481)



* Fix Device Consistency in Autotuner Threads and Add Manual Profiler Check

* lint fix

* Update example_mla_decode.py

* Update __init__.py

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 2297af9a
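A minimal sketch of the new hook's contract, with hypothetical names: as Profiler.manual_assert_close further down shows, the callable supplied as manual_check_prog is invoked as manual_check_prog(lib_outs, ref_outs), with kernel outputs and reference outputs normalized to parallel lists of tensors, and it reports failure by raising.

import torch

def my_check(lib_outs, ref_outs):
    # Kernel outputs and reference outputs arrive as parallel lists of tensors.
    for lib_out, ref_out in zip(lib_outs, ref_outs):
        torch.testing.assert_close(lib_out, ref_out, rtol=1e-2, atol=1e-2)

# The hook rides the same plumbing as skip_check/rtol/atol, e.g. through the
# jit decorator changed in this commit (kernel body elided, names hypothetical):
#
#     @jit(out_idx=[-1], manual_check_prog=my_check)
#     def my_kernel(...): ...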
@@ -11,6 +11,7 @@ from functools import wraps, partial
 from typing import Callable, List, Literal, Any, Optional, Union
 from tqdm import tqdm
 import logging
+import functools
 from dataclasses import dataclass
 import concurrent.futures
 import torch
@@ -61,6 +62,7 @@ class JITContext:
     atol: float
     max_mismatched_ratio: float
     skip_check: bool
+    manual_check_prog: Callable
     cache_input_tensors: bool
     kernel: tilelang.JITKernel
     supply_type: tilelang.TensorSupplyType
@@ -104,6 +106,7 @@ class CompileArgs:
     atol: float = 1e-2
     max_mismatched_ratio: float = 0.01
     skip_check: bool = False
+    manual_check_prog: Callable = None
     cache_input_tensors: bool = True
     target: Literal['auto', 'cuda', 'hip'] = 'auto'
     """
@@ -116,6 +119,7 @@ class CompileArgs:
        atol: float = 1e-2
        max_mismatched_ratio: float = 0.01
        skip_check: bool = False
+       manual_check_prog: Callable = None
        cache_input_tensors: bool = True
        target: Literal['auto', 'cuda', 'hip'] = 'auto'
@@ -162,6 +166,7 @@ class AutoTuner:
                          atol: float = 1e-2,
                          max_mismatched_ratio: float = 0.01,
                          skip_check: bool = False,
+                         manual_check_prog: Callable = None,
                          cache_input_tensors: bool = True,
                          target: Literal['auto', 'cuda', 'hip'] = 'auto'):
         """Set compilation arguments for the auto-tuner.
@@ -175,6 +180,7 @@
             atol: Absolute tolerance for validation.
             max_mismatched_ratio: Maximum allowed mismatch ratio.
             skip_check: Whether to skip validation.
+            manual_check_prog: Manual check program for validation.
             cache_input_tensors: Whether to cache input tensors.
             target: Target platform.
@@ -190,6 +196,7 @@
             atol=atol,
             max_mismatched_ratio=max_mismatched_ratio,
             skip_check=skip_check,
+            manual_check_prog=manual_check_prog,
             cache_input_tensors=cache_input_tensors,
             target=target)
@@ -232,6 +239,7 @@
             atol=compile_args.atol,
             max_mismatched_ratio=compile_args.max_mismatched_ratio,
             skip_check=compile_args.skip_check,
+            manual_check_prog=compile_args.manual_check_prog,
             cache_input_tensors=compile_args.cache_input_tensors,
             kernel=kernel,
             supply_type=compile_args.supply_type,
@@ -246,6 +254,7 @@
             kernel = jit_context.kernel
             supply_type = jit_context.supply_type
             skip_check = jit_context.skip_check
+            manual_check_prog = jit_context.manual_check_prog
             cache_input_tensors = jit_context.cache_input_tensors
             ref_prog = jit_context.ref_prog
             supply_prog = jit_context.supply_prog
@@ -291,12 +300,18 @@
                 self.jit_input_tensors = jit_input_tensors_supply()

             if (not skip_check) and (ref_prog is not None):
-                profiler.assert_allclose(
-                    ref_prog,
-                    input_tensors=self.jit_input_tensors,
-                    rtol=rtol,
-                    atol=atol,
-                    max_mismatched_ratio=max_mismatched_ratio)
+                if manual_check_prog is not None:
+                    profiler.manual_assert_close(
+                        ref_prog,
+                        input_tensors=self.jit_input_tensors,
+                        manual_check_prog=manual_check_prog)
+                else:
+                    profiler.assert_allclose(
+                        ref_prog,
+                        input_tensors=self.jit_input_tensors,
+                        rtol=rtol,
+                        atol=atol,
+                        max_mismatched_ratio=max_mismatched_ratio)
             latency = profiler.do_bench(
                 warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
             if self.ref_latency_cache is None and ref_prog is not None:
@@ -323,9 +338,14 @@
         pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
         futures = []
         future_to_index = {}
+
+        def device_wrapper(func, device, *config_arg):
+            torch.cuda.set_device(device)
+            return func(*config_arg)
+
         for i, config_arg in enumerate(config_args):
             future = pool.submit(
-                self.jit_compile,
+                functools.partial(device_wrapper, self.jit_compile, torch.cuda.current_device()),
                 *config_arg,
             )
             futures.append(future)
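The device_wrapper above addresses a real PyTorch behavior: the current CUDA device is per-thread state, so a freshly spawned ThreadPoolExecutor worker runs on the default device (cuda:0) regardless of what the submitting thread selected. A standalone sketch of the symptom, assuming a machine with at least two CUDA devices:

import concurrent.futures
import torch

torch.cuda.set_device(1)  # main thread now targets cuda:1

def report_worker_device():
    # Worker threads do not inherit the main thread's device selection.
    return torch.cuda.current_device()

with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
    worker_dev = pool.submit(report_worker_device).result()

print(torch.cuda.current_device(), worker_dev)  # prints: 1 0
# Binding the caller's device with functools.partial and re-applying it via
# torch.cuda.set_device inside the worker, as the diff does, closes this gap.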
@@ -355,7 +375,9 @@
             # Because tma init may behave strangely with one thread
             # latency, ref_latency = target_fn(jit_context)
             benchmark_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-            future = benchmark_executor.submit(target_fn, jit_context)
+            future = benchmark_executor.submit(
+                functools.partial(device_wrapper, target_fn, torch.cuda.current_device()),
+                jit_context)
             latency, ref_latency = future.result(timeout=timeout)
         except concurrent.futures.TimeoutError:
             logger.info(
@@ -434,6 +456,7 @@ def jit(out_idx: Optional[List[int]] = None,
         atol: float = 1e-2,
         max_mismatched_ratio: float = 0.01,
         skip_check: bool = False,
+        manual_check_prog: Callable = None,
         cache_input_tensors: bool = True,
         target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable:
     """Just-In-Time compilation decorator for tilelang programs.
@@ -447,6 +470,7 @@ def jit(out_idx: Optional[List[int]] = None,
         atol: Absolute tolerance for output validation.
         max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
         skip_check: Whether to skip validation checks.
+        manual_check_prog: Manual check program for validation.
         cache_input_tensors: Whether to cache input tensors for each compilation.
         target: Target platform ('auto', 'cuda', or 'hip').
@@ -475,6 +499,7 @@ def jit(out_idx: Optional[List[int]] = None,
             atol=atol,
             max_mismatched_ratio=max_mismatched_ratio,
             skip_check=skip_check,
+            manual_check_prog=manual_check_prog,
             cache_input_tensors=cache_input_tensors,
             kernel=kernel,
             supply_type=supply_type,
@@ -121,6 +121,37 @@ class Profiler:
             ref_name="ref",
         )

+    def manual_assert_close(
+        self,
+        reference_program: Callable,
+        input_tensors: Optional[List[torch.Tensor]] = None,
+        manual_check_prog: Callable = None,
+    ):
+        """Validates kernel output against a reference implementation.
+
+        Args:
+            reference_program: Reference implementation to compare against
+            input_tensors: Optional pre-generated input tensors
+            manual_check_prog: Callable invoked as
+                manual_check_prog(lib_outs, ref_outs); it performs the
+                comparison itself and should raise on mismatch
+        """
+        ins = self._get_inputs() if input_tensors is None else input_tensors
+        ref_outs = reference_program(*ins)
+        torch.cuda.synchronize()
+        lib_outs = self.func(*ins)
+        torch.cuda.synchronize()
+
+        if isinstance(lib_outs, torch.Tensor):
+            lib_outs = [lib_outs]
+        if isinstance(ref_outs, torch.Tensor):
+            ref_outs = [ref_outs]
+        elif ref_outs is None:
+            ref_outs = []
+        assert len(lib_outs) == len(ref_outs), f"{len(lib_outs)=} not equal to {len(ref_outs)=}!"
+
+        torch.set_printoptions(edgeitems=torch.inf)
+        manual_check_prog(lib_outs, ref_outs)
+
     def assert_consistent(self, repeat=10):
         """Checks for kernel consistency across multiple runs.