Commit 3de9f13c authored by Lei Wang, committed by LeiWang1999

[Refactor] Introduce KernelParam integration across modules (#223)

* [Refactor] Update KernelParam integration across modules

- Replaced instances of TensorType with KernelParam in various modules to standardize parameter handling.
- Updated JITKernel, BaseKernelAdapter, and CythonKernelAdapter to utilize KernelParam for improved type consistency.
- Enhanced Profiler class to include KernelParam in its parameters, ensuring better integration with the new parameter structure.
- Adjusted tensor handling in utility functions to accommodate the new KernelParam type, improving overall code clarity and maintainability.
- Updated copyright headers to reflect the correct organization.

* [Refactor] Clean up whitespace in kernel, profiler, and tensor modules

- Added blank lines for improved readability in kernel.py, __init__.py, and tensor.py.
- Enhanced code clarity by ensuring consistent formatting across these modules.

* [Enhancement] Add detailed docstrings to KernelParam and Profiler classes

- Enhanced KernelParam class with comprehensive docstrings for better understanding of its purpose and methods.
- Updated Profiler class to include detailed docstrings for its attributes and methods, improving code documentation and usability.
- Removed unused do_bench function to streamline the profiler module and improve clarity.

* [Refactor] Update type hints in do_bench function and clean up whitespace in profiler module

- Changed type hints for grad_to_none and quantiles parameters in do_bench function to use Optional for better clarity.
- Added a blank line in __init__.py for improved readability and consistency in the profiler module.

* [Refactor] Update type hint in do_bench function for consistency

- Changed the return type hint in the do_bench function from a union type to a more explicit List type for better clarity and consistency in type annotations.

* [Refactor] Update return type hint in do_bench function for clarity

- Changed the return type hint in the do_bench function from a union type to Union[float, List[float]] for improved clarity and consistency in type annotations (a short sketch of both return shapes follows below).
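
A minimal sketch of the two return shapes, assuming a CUDA device; `do_bench` here is the helper this PR later moves to `tilelang.profiler.bench`, and the matmul is purely illustrative:

import torch
from tilelang.profiler.bench import do_bench

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

# Without quantiles: a single aggregated float (mean by default).
mean_ms = do_bench(lambda: a @ b, warmup=25, rep=100)

# With quantiles: one float per requested percentile, i.e. List[float].
p20_ms, p80_ms = do_bench(lambda: a @ b, quantiles=[0.2, 0.8])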

* [Enhancement] Add func property to Profiler class for adapter access

- Introduced a new property `func` in the Profiler class to provide access to the adapter, ensuring that the adapter is set before retrieval. This enhancement improves the usability of the Profiler class by allowing easier access to the adapter functionality.

* [Refactor] Update kernel compilation and profiling in tests

- Replaced instances of `TL.lower` and `TL.Profiler` with `tilelang.compile` and the new profiler interface across multiple test files.
- Enhanced the kernel compilation process to utilize the updated API, improving consistency and maintainability in the testing framework.
- Updated assertions to use the new profiler methods for better clarity and functionality in performance testing (a before/after sketch follows below).
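
A hedged before/after sketch of that migration. `tilelang.compile` and `get_profiler` appear in this PR; the `out_idx` keyword, the `assert_allclose` method name (its signature tail is visible in the diff below, the name itself is not), and the toy copy kernel are assumptions for illustration:

import torch
import tilelang
import tilelang.language as T


@T.prim_func
def copy_kernel(A: T.Buffer((1024,), "float16"), B: T.Buffer((1024,), "float16")):
    # Toy elementwise copy: 8 blocks x 128 threads, one element per thread.
    with T.Kernel(8, threads=128) as bx:
        for i in T.Parallel(128):
            B[bx * 128 + i] = A[bx * 128 + i]


# Before: rt_mod, params = TL.lower(copy_kernel)
#         profiler = TL.Profiler(rt_mod, params, result_idx=[1])
# After:
kernel = tilelang.compile(copy_kernel, out_idx=[1])  # B is the output
C = kernel(torch.randn(1024, dtype=torch.float16, device="cuda"))

profiler = kernel.get_profiler()
profiler.assert_allclose(lambda a: a, atol=1e-2, rtol=1e-2)
latency_ms = profiler.do_bench()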

* [Refactor] Simplify kernel invocation and remove unused parameters in tests

- Updated the kernel invocation in `test_tilelang_dynamic_symbolic.py` to directly assign the result to `C`, improving clarity.
- Removed the `execution_backend` parameter from `tilelang.compile` calls in `test_tilelang_jit_callback.py` and `test_tilelang_jit_gemm.py` for consistency with the updated API.
- Commented out the call to `tilelang.testing.main()` in `test_tilelang_jit_callback.py` and replaced it with a direct call to `test_gemm_jit_kernel()` to streamline test execution.
- Adjusted the dtype mapping in `TorchDLPackKernelAdapter` to use the parameter's dtype directly, enhancing code simplicity.

* [Refactor] Remove unused imports in test files for cleaner code

- Eliminated unnecessary imports of `tilelang` as `TL` in various test files to enhance code clarity and maintainability.
- Updated multiple test files to streamline the codebase and reduce potential confusion from unused references.

* [Refactor] Simplify kernel invocation in tilelang kernel test

- Updated the kernel invocation in `test_tilelang_kernel_bf16_gemm_mma.py` to directly assign the result to `C`, enhancing code clarity and consistency with recent changes in the API.

* [Refactor] Simplify kernel invocation in tilelang kernel tests

- Updated kernel invocations in multiple test files to directly assign the result to `C`, improving code clarity and consistency with the updated API.
- Removed unnecessary initialization of `C` as a zero tensor, streamlining the code further.

* [Refactor] Update kernel invocation in tilelang transform tests

- Replaced the use of `TL.Profiler` with `tilelang.compile` in `test_tilelang_transform_simplify.py`, enhancing code clarity and consistency with the updated API.
- Streamlined the kernel invocation process by directly assigning the result to `C`, improving readability and maintainability of the test code.
parent 6bc8d6d3
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 from .lower import lower, is_device_call  # noqa: F401
+from .param import KernelParam  # noqa: F401
...
@@ -4,12 +4,13 @@
 import os
 import os.path as osp
-from typing import Union, Optional, Callable
+from typing import Union, Optional, Callable, List
 from tilelang import tvm as tvm
-from tvm import tir, relay
+from tvm import tir
 from tvm.ir import CallingConv
 from tvm.target import Target
 from tilelang.contrib import hipcc, nvcc
+from tilelang.engine.param import KernelParam
 from tilelang.utils.target import determine_target
 from tilelang.engine.phase import (
     LowerAndLegalize,
@@ -117,14 +118,13 @@ def tilelang_callback_hip_compile(code, target):
     return hsaco


-def extrac_params(func: tir.PrimFunc):
+def extrac_params(func: tir.PrimFunc) -> List[KernelParam]:
     tensor_types = []
     for var in func.params:
         if var in func.buffer_map:
-            tensor_types.append(
-                relay.TensorType(func.buffer_map[var].shape, func.buffer_map[var].dtype))
+            tensor_types.append(KernelParam.from_buffer(func.buffer_map[var]))
         else:
-            tensor_types.append(relay.scalar_type(var.dtype))
+            tensor_types.append(KernelParam.from_var(var))
     return tensor_types
...
"""The profiler and convert to torch utils"""
from dataclasses import dataclass
from typing import List, Union
import torch
from tilelang import tvm as tvm
from tvm.tir import Buffer, IntImm, Var
from tilelang.utils.tensor import map_torch_type
@dataclass
class KernelParam:
"""
Represents parameters for a kernel operation, storing dtype and shape information.
Used to describe tensor or scalar parameters in TVM/PyTorch interop.
"""
dtype: torch.dtype # PyTorch data type of the parameter
shape: List[Union[int, Var]] # List of dimensions, can be integers or TVM variables
@classmethod
def from_buffer(cls, buffer: Buffer):
"""
Creates a KernelParam instance from a TVM Buffer object.
Args:
buffer: TVM Buffer object containing dtype and shape information
Returns:
KernelParam instance with converted dtype and shape
Raises:
ValueError: If dimension type is not supported (not IntImm or Var)
"""
dtype = map_torch_type(buffer.dtype)
shape = []
for s in buffer.shape:
if isinstance(s, IntImm):
shape.append(s.value)
elif isinstance(s, Var):
shape.append(s)
else:
raise ValueError(f"Unsupported dimension type: {type(s)}")
return cls(dtype, shape)
@classmethod
def from_var(cls, var: Var):
"""
Creates a KernelParam instance from a TVM Variable object.
Used for scalar parameters.
Args:
var: TVM Variable object containing dtype information
Returns:
KernelParam instance representing a scalar (empty shape)
"""
return cls(var.dtype, [])
def is_scalar(self) -> bool:
"""
Checks if the parameter represents a scalar value.
Returns:
bool: True if parameter has no dimensions (empty shape), False otherwise
"""
return len(self.shape) == 0
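
A small usage sketch of the class above, grounded in the constructors shown (note that `from_var` stores the TVM dtype string as-is rather than mapping it to a `torch.dtype`):

import torch
from tvm import tir
from tilelang.engine.param import KernelParam

# Static 2-D buffer -> tensor-like parameter with a mapped torch dtype.
A = tir.decl_buffer((256, 512), dtype="float16", name="A")
p = KernelParam.from_buffer(A)
assert p.dtype is torch.float16
assert p.shape == [256, 512] and not p.is_scalar()

# Bare TIR variable -> scalar parameter (empty shape, TVM dtype string).
n = tir.Var("n", "int32")
s = KernelParam.from_var(n)
assert s.is_scalar() and s.dtype == "int32"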
...
@@ -4,14 +4,14 @@
 from abc import ABC, abstractmethod
 from typing import Any, List, Callable, Optional
-from tvm.relay import TensorType
+from tilelang.engine.param import KernelParam


 class BaseKernelAdapter(ABC):

     func: Optional[Callable] = None

-    def __init__(self, mod, params: List[TensorType], result_idx: List[int]) -> None:
+    def __init__(self, mod, params: List[KernelParam], result_idx: List[int]) -> None:
         self.mod = mod
         self.params = params
         self.result_idx = self._legalize_result_idx(result_idx)
...
...
@@ -12,7 +12,6 @@ from tilelang.jit.adapter.wrapper import TLWrapper
 from tilelang.jit.adapter.libgen import LibraryGenerator
 from tilelang.utils.target import determine_target
 from tilelang.utils.language import retrieve_func_from_module
-from tilelang.utils.tensor import map_torch_type


 class CtypesKernelAdapter(BaseKernelAdapter):
@@ -66,7 +65,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
         self.ir_module = func_or_mod

         # Cache parameter information during initialization
-        self.param_dtypes = [map_torch_type(param.dtype) for param in params]
+        self.param_dtypes = [param.dtype for param in params]
         self.param_shapes = []
         for param in params:
             native_shape = []
...
...
@@ -5,7 +5,7 @@ import ctypes
 from typing import List, Optional, Union, Callable, Dict, Tuple, Any
 from tilelang import tvm as tvm
 from tvm.target import Target
-from tvm.relay import TensorType
+from tilelang.engine.param import KernelParam
 from tvm import tir
 from tilelang.jit.adapter.wrapper import TLWrapper
 from tilelang.jit.adapter.libgen import LibraryGenerator
@@ -149,7 +149,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
     def __init__(self,
                  rt_mod,
-                 params: List[TensorType],
+                 params: List[KernelParam],
                  result_idx: List[int],
                  target: Union[str, Target],
                  func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
...
...
@@ -28,7 +28,7 @@ cdef class CythonKernelWrapper:
         self.params = params
         self.lib = lib
         # Convert TVM types to native Python types during initialization
-        self.param_dtypes = [map_torch_type(param.dtype) for param in params]
+        self.param_dtypes = [param.dtype for param in params]
         # Convert TVM shape arrays to native Python lists
         self.param_shapes = []
         for param in params:
...
...
@@ -6,7 +6,6 @@ import torch
 from typing import List
 from tilelang.contrib.dlpack import to_pytorch_func
 from .base import BaseKernelAdapter
-from tilelang.utils.tensor import map_torch_type


 class TorchDLPackKernelAdapter(BaseKernelAdapter):
@@ -27,7 +26,7 @@ class TorchDLPackKernelAdapter(BaseKernelAdapter):
         for i in range(len(self.params)):
             if i in self.result_idx:
-                dtype = map_torch_type(self.params[i].dtype)
+                dtype = self.params[i].dtype
                 shape = list(map(int, self.params[i].shape))
                 tensor = torch.empty(*shape, dtype=dtype, device=device)
             else:
...
...
@@ -7,6 +7,7 @@ from tvm.tir import PrimFunc
 from tilelang.jit.adapter import TorchDLPackKernelAdapter, BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter
 from tilelang.utils.target import determine_target, AVALIABLE_TARGETS
 from tilelang.profiler import Profiler, TensorSupplyType
+from tilelang.engine.param import KernelParam


 class JITKernel(object):
@@ -15,15 +16,15 @@ class JITKernel(object):
     Attributes
     ----------
-    rt_module : tvm.runtime.Module
+    rt_mod : tvm.runtime.Module
         The runtime module compiled by TVM.
-    rt_params : dict
+    params : List[KernelParam]
         Parameters for the compiled runtime module (e.g., weights or constants).
     torch_function : Callable
         The compiled function that can be invoked as a PyTorch-compatible function.
     """
-    rt_module: tvm.runtime.Module = None
-    rt_params: dict = None
+    rt_mod: tvm.runtime.Module = None
+    params: List[KernelParam] = None
     adapter: BaseKernelAdapter = None
     torch_function: Callable = None
@@ -138,8 +139,8 @@ class JITKernel(object):
         rt_mod, params = tilelang.lower(tilelang_func, target=target, target_host=target_host)

         # Store the runtime module and parameters for later use.
-        self.rt_module = rt_mod
-        self.rt_params = params
+        self.rt_mod = rt_mod
+        self.params = params

         # Create an adapter based on the specified execution backend.
         if execution_backend == "dlpack":
@@ -205,7 +206,8 @@
         Profiler
             A Profiler instance for benchmarking the runtime module.
         """
-        return Profiler(self.rt_module, self.rt_params, self.out_idx, tensor_supply_type)
+        return Profiler(self.params, self.out_idx,
+                        tensor_supply_type).with_default_adapter(self.adapter)

     def get_kernel_source(self) -> str:
         """
@@ -218,13 +220,13 @@
         """
         if self.execution_backend in {"ctypes", "cython"}:
             return self.adapter.get_kernel_source()
-        return self.rt_module.imported_modules[0].get_source()
+        return self.rt_mod.imported_modules[0].get_source()

     def get_host_source(self) -> str:
         """
         Returns the source code of the host function.
         """
-        return self.rt_module.get_source()
+        return self.rt_mod.get_source()

     def run_once(self, func: Optional[Callable] = None) -> None:
         return self.get_profiler().run_once(func)
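
Continuing the compile sketch from the commit message above, the renamed surface would be exercised roughly like this (a sketch; `kernel` is assumed to come from `tilelang.compile`):

print(type(kernel.rt_mod))               # tvm.runtime.Module (was `rt_module`)
print([p.shape for p in kernel.params])  # KernelParam shapes (was `rt_params`)

cuda_src = kernel.get_kernel_source()    # device source (adapter-backed for ctypes/cython)
host_src = kernel.get_host_source()      # host-side source from the runtime module
kernel.run_once()                        # single execution through a fresh profiler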
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from typing import List, Literal, Optional, Callable from typing import List, Literal, Optional, Callable, Any
from functools import partial from functools import partial
import torch import torch
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass
import tvm import tvm
from tvm.relay import TensorType
from tilelang.jit.adapter import TorchDLPackKernelAdapter
from tilelang.utils.tensor import ( from tilelang.utils.tensor import (
get_tensor_supply, get_tensor_supply,
TensorSupplyType, TensorSupplyType,
torch_assert_close, torch_assert_close,
adapt_torch2tvm, adapt_torch2tvm,
) )
from tilelang.engine.param import KernelParam
from tilelang.jit.adapter import BaseKernelAdapter
from tilelang.profiler.bench import do_bench
class Profiler(TorchDLPackKernelAdapter): @dataclass
class Profiler:
"""A profiler class for benchmarking and validating kernel implementations.
Attributes:
params: List of kernel parameters defining the input/output specifications
result_idx: Indices indicating which parameters are output tensors
supply_type: Type of tensor supply to use (e.g., random, zeros, etc.)
adapter: Optional kernel adapter for interfacing with different backends
"""
def __init__( params: List[KernelParam]
self, result_idx: List[int]
mod, supply_type: TensorSupplyType
params: List[TensorType], adapter: Optional[BaseKernelAdapter] = None
result_idx: List[int],
supply_type: TensorSupplyType = TensorSupplyType.Normal, def __post_init__(self):
): """Initialize tensor supply after dataclass initialization"""
super().__init__(mod, params, result_idx) self.result_idx = self._legalize_result_idx(self.result_idx)
self.supply = get_tensor_supply(supply_type) self.supply = get_tensor_supply(self.supply_type)
def _legalize_result_idx(self, result_idx: Optional[List[int]] = None) -> List[int]:
params = self.params
# result_idx is a list of indices of the output tensors
if result_idx is None:
result_idx = []
elif isinstance(result_idx, int):
if result_idx > len(params) or result_idx < -len(params):
raise ValueError(
f"result_idx should be an integer between {-len(params)} and {len(params) - 1}")
if result_idx < 0:
result_idx = len(params) + result_idx
result_idx = [result_idx]
elif not isinstance(result_idx, list):
raise ValueError("result_idx should be a list of integers")
return result_idx
def with_default_adapter(self, adapter: BaseKernelAdapter) -> "Profiler":
self.adapter = adapter
return self
def _get_inputs(self, with_output=False): def _get_inputs(self, with_output=False):
ins = [] ins = []
...@@ -42,6 +73,14 @@ class Profiler(TorchDLPackKernelAdapter): ...@@ -42,6 +73,14 @@ class Profiler(TorchDLPackKernelAdapter):
rtol: float = 1e-2, rtol: float = 1e-2,
max_mismatched_ratio=0.01, max_mismatched_ratio=0.01,
): ):
"""Validates kernel output against a reference implementation.
Args:
reference_program: Reference implementation to compare against
atol: Absolute tolerance for comparison
rtol: Relative tolerance for comparison
max_mismatched_ratio: Maximum allowed ratio of mismatched elements
"""
ins = self._get_inputs() ins = self._get_inputs()
ref_outs = reference_program(*ins) ref_outs = reference_program(*ins)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -70,6 +109,11 @@ class Profiler(TorchDLPackKernelAdapter): ...@@ -70,6 +109,11 @@ class Profiler(TorchDLPackKernelAdapter):
) )
def assert_consistent(self, repeat=10): def assert_consistent(self, repeat=10):
"""Checks for kernel consistency across multiple runs.
Args:
repeat: Number of times to repeat the consistency check
"""
# Used to check no race condition inside the kernel # Used to check no race condition inside the kernel
ins = self._get_inputs() ins = self._get_inputs()
ref_outs = self.func(*ins) ref_outs = self.func(*ins)
...@@ -92,8 +136,17 @@ class Profiler(TorchDLPackKernelAdapter): ...@@ -92,8 +136,17 @@ class Profiler(TorchDLPackKernelAdapter):
def determine_profiler(self, def determine_profiler(self,
func: Optional[Callable] = None, func: Optional[Callable] = None,
profiler: Literal["torch", "tvm", "auto"] = "auto"): profiler: Literal["torch", "tvm", "auto"] = "auto"):
"""Determines which profiler backend to use based on function type.
Args:
func: Function to be profiled
profiler: Explicitly specified profiler type or "auto" for automatic detection
Returns:
str: The determined profiler type ("torch" or "tvm")
"""
if profiler == "auto": if profiler == "auto":
if func is None or isinstance(func, tvm.runtime.Module): if isinstance(func, tvm.runtime.Module):
return "tvm" return "tvm"
else: else:
return "torch" return "torch"
...@@ -109,8 +162,25 @@ class Profiler(TorchDLPackKernelAdapter): ...@@ -109,8 +162,25 @@ class Profiler(TorchDLPackKernelAdapter):
profiler: Literal["torch", "tvm", "auto"] = "auto", profiler: Literal["torch", "tvm", "auto"] = "auto",
input_tensors: List[torch.Tensor] = None, input_tensors: List[torch.Tensor] = None,
) -> float: ) -> float:
"""Benchmarks the execution time of a given function.
Args:
func: Function to benchmark (uses adapter if None)
warmup: Warmup time in milliseconds
rep: Number of repetitions for timing
n_warmup: Number of warmup iterations
n_repeat: Number of timing iterations
profiler: Which profiling backend to use
input_tensors: Optional pre-generated input tensors
Returns:
float: Average execution time in milliseconds
"""
profiler = self.determine_profiler(func, profiler) profiler = self.determine_profiler(func, profiler)
if profiler == "torch": if profiler == "torch":
if func is None:
assert self.adapter is not None, "benchmarking function should be provided"
func = self.adapter
ins = self._get_inputs() if input_tensors is None else input_tensors ins = self._get_inputs() if input_tensors is None else input_tensors
bench_func = partial(func, *ins) bench_func = partial(func, *ins)
return do_bench( return do_bench(
...@@ -121,10 +191,10 @@ class Profiler(TorchDLPackKernelAdapter): ...@@ -121,10 +191,10 @@ class Profiler(TorchDLPackKernelAdapter):
_n_repeat=n_repeat, _n_repeat=n_repeat,
) )
elif profiler == "tvm": elif profiler == "tvm":
if func is None: assert func is not None, "func should not be None"
func = self.mod
assert isinstance( assert isinstance(
func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}" func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}"
ins = (self._get_inputs(with_output=True) if input_tensors is None else input_tensors) ins = (self._get_inputs(with_output=True) if input_tensors is None else input_tensors)
target = "cuda" target = "cuda"
...@@ -142,97 +212,10 @@ class Profiler(TorchDLPackKernelAdapter): ...@@ -142,97 +212,10 @@ class Profiler(TorchDLPackKernelAdapter):
else: else:
raise ValueError(f"Unknown profiler: {profiler}") raise ValueError(f"Unknown profiler: {profiler}")
@property
def func(self):
assert self.adapter is not None, "adapter should be provided"
return self.adapter
def do_bench( def __call__(self, *args: Any, **kwds: Any) -> Any:
fn, return self.func(*args, **kwds)
warmup=25,
rep=100,
_n_warmup=0,
_n_repeat=0,
grad_to_none=None,
quantiles=None,
fast_flush=True,
return_mode="mean",
) -> float:
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
:type warmup: int
:param rep: Repetition time (in ms)
:type rep: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
:param quantiles: Performance percentile to return in addition to the median.
:type quantiles: list[float]
:param fast_flush: Use faster kernel to flush L2 between measurements
:type fast_flush: bool
Returns:
float: The median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
"""
assert return_mode in ["min", "max", "mean", "median"]
fn()
torch.cuda.synchronize()
# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
if fast_flush:
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
else:
cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda")
# Estimate the runtime of the function
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(5):
cache.zero_()
fn()
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5
# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
if _n_warmup > 0:
n_warmup = _n_warmup
if _n_repeat > 0:
n_repeat = _n_repeat
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
# Warm-up
for _ in range(n_warmup):
fn()
# Benchmark
for i in range(n_repeat):
# we don't want `fn` to accumulate gradient values
# if it contains a backward pass. So we clear the
# provided gradients
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
# we clear the L2 cache before each run
cache.zero_()
# record time of `fn`
start_event[i].record()
fn()
end_event[i].record()
# Record clocks
torch.cuda.synchronize()
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
if quantiles is not None:
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
if len(ret) == 1:
ret = ret[0]
return ret
return getattr(torch, return_mode)(times).item()
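
For reference, a sketch of how the reworked dataclass Profiler is meant to be constructed; it mirrors `JITKernel.get_profiler` above, with `kernel` assumed to be a compiled `JITKernel`:

from tilelang.profiler import Profiler, TensorSupplyType

profiler = Profiler(kernel.params, kernel.out_idx,
                    TensorSupplyType.Normal).with_default_adapter(kernel.adapter)

out = profiler(*profiler._get_inputs())  # __call__ forwards to the adapter via `func`
ms = profiler.do_bench()                 # func=None now falls back to the adapter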
"""The profiler and convert to torch utils"""
import torch
from typing import Callable, List, Literal, Optional, Union
def do_bench(
fn: Callable,
warmup: float = 25,
rep: float = 100,
_n_warmup: int = 0,
_n_repeat: int = 0,
grad_to_none: Optional[List[torch.Tensor]] = None,
quantiles: Optional[List[float]] = None,
fast_flush: bool = True,
return_mode: Literal["min", "max", "mean", "median"] = "mean",
) -> Union[float, List[float]]:
"""Benchmarks the runtime of a PyTorch function.
This function handles:
- L2 cache flushing between runs for consistent timing
- Automatic warmup and repeat count calculation
- Optional gradient clearing for backward passes
- Multiple measurement modes (mean, median, min, max)
Args:
fn: Function to benchmark
warmup: Target warmup time in milliseconds
rep: Target number of repetitions
_n_warmup: Override for number of warmup iterations
_n_repeat: Override for number of timing iterations
grad_to_none: Tensors whose gradients should be cleared between runs
quantiles: Optional performance percentiles to compute
fast_flush: Whether to use faster L2 cache flushing
return_mode: How to aggregate timing results ("mean", "median", "min", "max")
Returns:
float: Aggregated runtime in milliseconds
"""
assert return_mode in ["min", "max", "mean", "median"]
fn()
torch.cuda.synchronize()
# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
if fast_flush:
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
else:
cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda")
# Estimate the runtime of the function
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(5):
cache.zero_()
fn()
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5
# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
if _n_warmup > 0:
n_warmup = _n_warmup
if _n_repeat > 0:
n_repeat = _n_repeat
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
# Warm-up
for _ in range(n_warmup):
fn()
# Benchmark
for i in range(n_repeat):
# we don't want `fn` to accumulate gradient values
# if it contains a backward pass. So we clear the
# provided gradients
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
# we clear the L2 cache before each run
cache.zero_()
# record time of `fn`
start_event[i].record()
fn()
end_event[i].record()
# Record clocks
torch.cuda.synchronize()
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
if quantiles is not None:
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
if len(ret) == 1:
ret = ret[0]
return ret
return getattr(torch, return_mode)(times).item()
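
A short usage sketch of the knobs above, assuming a CUDA device; when the private overrides are left at 0, the iteration counts are derived from the 5-run estimate as shown in the code:

import torch
from tilelang.profiler.bench import do_bench

x = torch.randn(4096, 4096, device="cuda")

# Default: counts derived from the 25 ms warmup / 100 ms measurement budgets.
t_mean = do_bench(lambda: torch.relu(x))

# Pin the counts explicitly and aggregate by median instead of mean.
t_median = do_bench(lambda: torch.relu(x), _n_warmup=10, _n_repeat=50,
                    return_mode="median")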
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 """The profiler and convert to torch utils"""
 from enum import Enum

 import torch
-from tvm.relay import TensorType
 from tvm.runtime import ndarray
 from torch.utils.dlpack import to_dlpack
@@ -48,16 +45,18 @@ def adapt_torch2tvm(arg):
 def get_tensor_supply(supply_type: TensorSupplyType):
+    from tilelang.engine.param import KernelParam

-    def get_tensor(tensor: TensorType) -> torch.Tensor:
-        dtype = map_torch_type(str(tensor.dtype))
-        device = torch.cuda.current_device()
+    def get_tensor(param: KernelParam) -> torch.Tensor:
+        dtype: torch.dtype = param.dtype
+        device: torch.device = torch.cuda.current_device()

-        if hasattr(tensor, "shape") and not tensor.shape:
+        if hasattr(param, "shape") and not param.shape:
             raise ValueError(
-                f"TensorType must have a shape, but got {type(tensor)}, "
+                f"TensorType must have a shape, but got {type(param)}, "
                 "likely you are trying to generate a random tensor with a dynamic symbolic shape.")

-        shape = list(map(int, tensor.shape))
+        shape = list(map(int, param.shape))

         if supply_type == TensorSupplyType.Auto:
             if dtype == torch.float16 or dtype == torch.float32:
                 return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0)
@@ -73,8 +72,8 @@ def get_tensor_supply(supply_type: TensorSupplyType):
             return torch.ones(*shape, device=device, dtype=dtype)

         if supply_type == TensorSupplyType.Integer:
-            is_unsigned = tensor.dtype.startswith("uint")
-            is_float8 = tensor.dtype.endswith("float8")
+            is_unsigned = str(dtype).removeprefix("torch.").startswith("uint")
+            is_float8 = str(dtype).removeprefix("torch.").startswith("float8")
             if is_unsigned:
                 return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
             elif is_float8:
...
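
A closing sketch of the updated supply path, assuming a CUDA device and that `get_tensor_supply` returns the inner `get_tensor` closure (as its call sites suggest); the unsigned-integer branch is the one fully visible in the diff above:

import torch
from tvm import tir
from tilelang.engine.param import KernelParam
from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType

supply = get_tensor_supply(TensorSupplyType.Integer)

# Static uint8 parameter -> torch.uint8 tensor with values in [0, 3).
param = KernelParam.from_buffer(tir.decl_buffer((64, 64), dtype="uint8", name="A"))
t = supply(param)
assert t.dtype == torch.uint8 and t.shape == (64, 64)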