Commit 089cc0a7 authored by yuanjypku, committed by LeiWang1999

[Feature] Fix Device Consistency in Autotuner Threads and Add Manual Profiler Check (#481)



* Fix Device Consistency in Autotuner Threads and Add Manual Profiler Check

* lint fix

* Update example_mla_decode.py

* Update __init__.py

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 2297af9a
@@ -11,6 +11,7 @@ from functools import wraps, partial
 from typing import Callable, List, Literal, Any, Optional, Union
 from tqdm import tqdm
 import logging
+import functools
 from dataclasses import dataclass
 import concurrent.futures
 import torch
@@ -61,6 +62,7 @@ class JITContext:
     atol: float
     max_mismatched_ratio: float
     skip_check: bool
+    manual_check_prog: Callable
     cache_input_tensors: bool
     kernel: tilelang.JITKernel
     supply_type: tilelang.TensorSupplyType
@@ -104,6 +106,7 @@ class CompileArgs:
     atol: float = 1e-2
     max_mismatched_ratio: float = 0.01
     skip_check: bool = False
+    manual_check_prog: Callable = None
     cache_input_tensors: bool = True
     target: Literal['auto', 'cuda', 'hip'] = 'auto'
     """
@@ -116,6 +119,7 @@ class CompileArgs:
         atol: float = 1e-2
         max_mismatched_ratio: float = 0.01
         skip_check: bool = False
+        manual_check_prog: Callable = None
         cache_input_tensors: bool = True
         target: Literal['auto', 'cuda', 'hip'] = 'auto'
@@ -162,6 +166,7 @@ class AutoTuner:
                          atol: float = 1e-2,
                          max_mismatched_ratio: float = 0.01,
                          skip_check: bool = False,
+                         manual_check_prog: Callable = None,
                          cache_input_tensors: bool = True,
                          target: Literal['auto', 'cuda', 'hip'] = 'auto'):
         """Set compilation arguments for the auto-tuner.
@@ -175,6 +180,7 @@ class AutoTuner:
             atol: Absolute tolerance for validation.
             max_mismatched_ratio: Maximum allowed mismatch ratio.
             skip_check: Whether to skip validation.
+            manual_check_prog: Manual check program for validation.
             cache_input_tensors: Whether to cache input tensors.
             target: Target platform.
@@ -190,6 +196,7 @@ class AutoTuner:
             atol=atol,
             max_mismatched_ratio=max_mismatched_ratio,
             skip_check=skip_check,
+            manual_check_prog=manual_check_prog,
             cache_input_tensors=cache_input_tensors,
             target=target)
@@ -232,6 +239,7 @@ class AutoTuner:
                 atol=compile_args.atol,
                 max_mismatched_ratio=compile_args.max_mismatched_ratio,
                 skip_check=compile_args.skip_check,
+                manual_check_prog=compile_args.manual_check_prog,
                 cache_input_tensors=compile_args.cache_input_tensors,
                 kernel=kernel,
                 supply_type=compile_args.supply_type,
@@ -246,6 +254,7 @@ class AutoTuner:
         kernel = jit_context.kernel
         supply_type = jit_context.supply_type
         skip_check = jit_context.skip_check
+        manual_check_prog = jit_context.manual_check_prog
         cache_input_tensors = jit_context.cache_input_tensors
         ref_prog = jit_context.ref_prog
         supply_prog = jit_context.supply_prog
@@ -291,6 +300,12 @@ class AutoTuner:
                 self.jit_input_tensors = jit_input_tensors_supply()
             if (not skip_check) and (ref_prog is not None):
+                if manual_check_prog is not None:
+                    profiler.manual_assert_close(
+                        ref_prog,
+                        input_tensors=self.jit_input_tensors,
+                        manual_check_prog=manual_check_prog)
+                else:
                     profiler.assert_allclose(
                         ref_prog,
                         input_tensors=self.jit_input_tensors,
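Validation now routes through the user callback when one is supplied. A minimal sketch of such a callback (hypothetical, not part of this commit): `manual_assert_close` invokes it as `manual_check_prog(lib_outs, ref_outs)`, where both arguments are lists of torch.Tensor.

    import torch

    def my_manual_check(lib_outs, ref_outs):
        # Compare each kernel output against its reference, ignoring
        # non-finite reference entries (e.g. padded positions).
        for lib_out, ref_out in zip(lib_outs, ref_outs):
            mask = torch.isfinite(ref_out)
            torch.testing.assert_close(
                lib_out[mask], ref_out[mask], rtol=1e-2, atol=1e-2)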
@@ -323,9 +338,14 @@ class AutoTuner:
         pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
         futures = []
         future_to_index = {}
+
+        def device_wrapper(func, device, *config_arg):
+            torch.cuda.set_device(device)
+            return func(*config_arg)
+
         for i, config_arg in enumerate(config_args):
             future = pool.submit(
-                self.jit_compile,
+                functools.partial(device_wrapper, self.jit_compile, torch.cuda.current_device()),
                 *config_arg,
             )
             futures.append(future)
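The wrapper exists because torch.cuda tracks the current device per thread: tasks run on ThreadPoolExecutor workers default to device 0 even if the submitting thread selected another GPU. A standalone sketch of the problem and the fix, assuming a machine with at least two CUDA devices (not library code):

    import concurrent.futures
    import functools
    import torch

    def report_device():
        return torch.cuda.current_device()

    def pinned(func, device, *args):
        torch.cuda.set_device(device)  # set the worker thread's device first
        return func(*args)

    torch.cuda.set_device(1)  # assumption: >= 2 GPUs available
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        # Worker thread ignores the caller's device: prints 0.
        print(pool.submit(report_device).result())
        # Pinning the caller's device inside the task: prints 1.
        print(pool.submit(
            functools.partial(pinned, report_device,
                              torch.cuda.current_device())).result())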
@@ -355,7 +375,9 @@ class AutoTuner:
             # Because tma init may behave strangely with one thread
             # latency, ref_latency = target_fn(jit_context)
             benchmark_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-            future = benchmark_executor.submit(target_fn, jit_context)
+            future = benchmark_executor.submit(
+                functools.partial(device_wrapper, target_fn, torch.cuda.current_device()),
+                jit_context)
             latency, ref_latency = future.result(timeout=timeout)
         except concurrent.futures.TimeoutError:
             logger.info(
@@ -434,6 +456,7 @@ def jit(out_idx: Optional[List[int]] = None,
         atol: float = 1e-2,
         max_mismatched_ratio: float = 0.01,
         skip_check: bool = False,
+        manual_check_prog: Callable = None,
         cache_input_tensors: bool = True,
         target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable:
     """Just-In-Time compilation decorator for tilelang programs.
@@ -447,6 +470,7 @@ def jit(out_idx: Optional[List[int]] = None,
         atol: Absolute tolerance for output validation.
         max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
         skip_check: Whether to skip validation checks.
+        manual_check_prog: Manual check program for validation.
         cache_input_tensors: Whether to cache input tensors for each compilation.
         target: Target platform ('auto', 'cuda', or 'hip').
@@ -475,6 +499,7 @@ def jit(out_idx: Optional[List[int]] = None,
             atol=atol,
             max_mismatched_ratio=max_mismatched_ratio,
             skip_check=skip_check,
+            manual_check_prog=manual_check_prog,
             cache_input_tensors=cache_input_tensors,
             kernel=kernel,
             supply_type=supply_type,
...
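End to end, the new keyword threads from the decorator down to the profiler. A hedged usage sketch (the import path, config list, reference program, and kernel body are illustrative, not from this commit; `my_manual_check` is the callback sketched above):

    from tilelang.autotuner import autotune, jit  # import path assumed

    @autotune(configs=my_configs, warmup=3, rep=10)  # my_configs: hypothetical
    @jit(out_idx=[-1],
         ref_prog=ref_impl,                          # ref_impl: hypothetical reference
         manual_check_prog=my_manual_check)          # custom validation callback
    def my_kernel(block_M=None, block_N=None):
        ...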
@@ -121,6 +121,37 @@ class Profiler:
             ref_name="ref",
         )
 
+    def manual_assert_close(
+        self,
+        reference_program: Callable,
+        input_tensors: Optional[List[torch.Tensor]] = None,
+        manual_check_prog: Callable = None,
+    ):
+        """Validates kernel output against a reference implementation
+        using a user-supplied check function.
+
+        Args:
+            reference_program: Reference implementation to compare against
+            input_tensors: Optional pre-generated input tensors
+            manual_check_prog: Callable invoked as
+                manual_check_prog(lib_outs, ref_outs) to perform the comparison
+        """
+        ins = self._get_inputs() if input_tensors is None else input_tensors
+        ref_outs = reference_program(*ins)
+        torch.cuda.synchronize()
+        lib_outs = self.func(*ins)
+        torch.cuda.synchronize()
+
+        if isinstance(lib_outs, torch.Tensor):
+            lib_outs = [lib_outs]
+        if isinstance(ref_outs, torch.Tensor):
+            ref_outs = [ref_outs]
+        elif ref_outs is None:
+            ref_outs = []
+        assert len(lib_outs) == len(ref_outs), \
+            f"{len(lib_outs)=} does not equal {len(ref_outs)=}!"
+        torch.set_printoptions(edgeitems=torch.inf)
+        manual_check_prog(lib_outs, ref_outs)
+
     def assert_consistent(self, repeat=10):
         """Checks for kernel consistency across multiple runs.
...
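The method can also be called on a profiler directly, outside the autotuner. A sketch under stated assumptions (`kernel.get_profiler()` and `ref_impl` are assumed here, not shown in this commit):

    profiler = kernel.get_profiler()      # assumes a compiled tilelang.JITKernel
    profiler.manual_assert_close(
        ref_impl,                         # hypothetical reference program
        manual_check_prog=my_manual_check)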