Unverified Commit b8003a28 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Include PrimFunc name in memory cache logs for better debugging (#1437)

* Added the `get_prim_func_name` utility to extract human-readable function names from TVM PrimFuncs.
* Updated memory cache logging in `AutoTuner` and `KernelCache` classes to include the kernel name, improving clarity during cache hits.
* Enhanced debug logging to provide more informative messages when checking disk cache for kernels.
parent 2feaa41e
...@@ -37,6 +37,7 @@ from pathlib import Path ...@@ -37,6 +37,7 @@ from pathlib import Path
from tilelang import env from tilelang import env
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.utils.language import get_prim_func_name
from tilelang.autotuner.capture import get_autotune_inputs from tilelang.autotuner.capture import get_autotune_inputs
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
from tilelang import __version__ from tilelang import __version__
...@@ -332,11 +333,15 @@ class AutoTuner: ...@@ -332,11 +333,15 @@ class AutoTuner:
if env.is_cache_enabled() and not env.is_autotune_cache_disabled(): if env.is_cache_enabled() and not env.is_autotune_cache_disabled():
# First check in-memory cache # First check in-memory cache
if key in self._memory_cache: if key in self._memory_cache:
# Include PrimFunc name when hitting autotuner memory cache
cached_result = self._memory_cache[key]
prim = getattr(cached_result, "func", None)
kernel_name = get_prim_func_name(prim, "<unknown>")
logger.warning( logger.warning(
"Found kernel in memory cache. For better performance," "Found kernel '%s' in memory cache. For better performance, consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.",
" consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel." kernel_name,
) )
return self._memory_cache[key] return cached_result
# Then check disk cache # Then check disk cache
result = self._load_result_from_disk(key) result = self._load_result_from_disk(key)
......
...@@ -16,6 +16,7 @@ from tvm.target import Target ...@@ -16,6 +16,7 @@ from tvm.target import Target
from tvm.tir import PrimFunc from tvm.tir import PrimFunc
from tvm.runtime import Executable from tvm.runtime import Executable
from tilelang.engine.param import KernelParam from tilelang.engine.param import KernelParam
from tilelang.utils.language import get_prim_func_name
from tilelang import env from tilelang import env
from tilelang.jit import JITKernel from tilelang.jit import JITKernel
from tilelang import __version__ from tilelang import __version__
...@@ -179,13 +180,16 @@ class KernelCache: ...@@ -179,13 +180,16 @@ class KernelCache:
with self._lock: with self._lock:
# First check in-memory cache # First check in-memory cache
if key in self._memory_cache: if key in self._memory_cache:
# Include kernel name for easier debugging when hitting memory cache
kernel_name = get_prim_func_name(func, "<unknown>")
self.logger.warning( self.logger.warning(
"Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching." "Found kernel '%s' in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching.",
kernel_name,
) )
return self._memory_cache[key] return self._memory_cache[key]
if verbose: if verbose:
self.logger.debug(f"Checking disk cache for kernel {func.attrs['global_symbol']}") self.logger.debug(f"Checking disk cache for kernel {get_prim_func_name(func, '<unknown>')}")
# Then check disk cache # Then check disk cache
kernel = self._load_kernel_from_disk( kernel = self._load_kernel_from_disk(
...@@ -193,13 +197,13 @@ class KernelCache: ...@@ -193,13 +197,13 @@ class KernelCache:
) )
if kernel is not None: if kernel is not None:
if verbose: if verbose:
self.logger.debug(f"Found kernel in disk cache for {func.attrs['global_symbol']}") self.logger.debug(f"Found kernel in disk cache for {get_prim_func_name(func, '<unknown>')}")
# Populate memory cache with disk result # Populate memory cache with disk result
self._memory_cache[key] = kernel self._memory_cache[key] = kernel
return kernel return kernel
if verbose: if verbose:
self.logger.debug(f"No cached kernel for {func.attrs['global_symbol']}") self.logger.debug(f"No cached kernel for {get_prim_func_name(func, '<unknown>')}")
# Compile kernel if cache miss; leave critical section # Compile kernel if cache miss; leave critical section
kernel = JITKernel( kernel = JITKernel(
func, func,
......
...@@ -16,5 +16,6 @@ from .language import ( ...@@ -16,5 +16,6 @@ from .language import (
is_full_region, # noqa: F401 is_full_region, # noqa: F401
to_buffer_region, # noqa: F401 to_buffer_region, # noqa: F401
get_buffer_region_from_load, # noqa: F401 get_buffer_region_from_load, # noqa: F401
get_prim_func_name, # noqa: F401
) )
from .deprecated import deprecated # noqa: F401 from .deprecated import deprecated # noqa: F401
...@@ -478,3 +478,27 @@ def is_full_region(buffer_region: BufferRegion) -> bool: ...@@ -478,3 +478,27 @@ def is_full_region(buffer_region: BufferRegion) -> bool:
if not expr_equal(r.extent, dim): if not expr_equal(r.extent, dim):
return False return False
return True return True
def get_prim_func_name(func: PrimFunc | None, default: str | None = None) -> str | None:
"""
Extract a human‑readable function name from a TVM PrimFunc.
Prefer the `global_symbol` attribute set on the PrimFunc. If it is missing
(e.g., private PrimFunc without a global symbol), return the provided
`default` value.
Args:
func: TVM PrimFunc instance or None.
default: Fallback name to return when no name can be determined.
Returns:
The function name as a string, or `default` when unavailable.
"""
if func is None:
return default
try:
name = func.attrs["global_symbol"]
return str(name) if name is not None else default
except Exception:
return default
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment