Unverified Commit b8003a28 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Include PrimFunc name in memory cache logs for better debugging (#1437)

* Added the `get_prim_func_name` utility to extract human-readable function names from TVM PrimFuncs.
* Updated memory cache logging in `AutoTuner` and `KernelCache` classes to include the kernel name, improving clarity during cache hits.
* Enhanced debug logging to provide more informative messages when checking disk cache for kernels.
parent 2feaa41e
......@@ -37,6 +37,7 @@ from pathlib import Path
from tilelang import env
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.utils.language import get_prim_func_name
from tilelang.autotuner.capture import get_autotune_inputs
from tilelang.utils.target import determine_target
from tilelang import __version__
......@@ -332,11 +333,15 @@ class AutoTuner:
if env.is_cache_enabled() and not env.is_autotune_cache_disabled():
# First check in-memory cache
if key in self._memory_cache:
# Include PrimFunc name when hitting autotuner memory cache
cached_result = self._memory_cache[key]
prim = getattr(cached_result, "func", None)
kernel_name = get_prim_func_name(prim, "<unknown>")
logger.warning(
"Found kernel in memory cache. For better performance,"
" consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel."
"Found kernel '%s' in memory cache. For better performance, consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.",
kernel_name,
)
return self._memory_cache[key]
return cached_result
# Then check disk cache
result = self._load_result_from_disk(key)
......
......@@ -16,6 +16,7 @@ from tvm.target import Target
from tvm.tir import PrimFunc
from tvm.runtime import Executable
from tilelang.engine.param import KernelParam
from tilelang.utils.language import get_prim_func_name
from tilelang import env
from tilelang.jit import JITKernel
from tilelang import __version__
......@@ -179,13 +180,16 @@ class KernelCache:
with self._lock:
# First check in-memory cache
if key in self._memory_cache:
# Include kernel name for easier debugging when hitting memory cache
kernel_name = get_prim_func_name(func, "<unknown>")
self.logger.warning(
"Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching."
"Found kernel '%s' in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching.",
kernel_name,
)
return self._memory_cache[key]
if verbose:
self.logger.debug(f"Checking disk cache for kernel {func.attrs['global_symbol']}")
self.logger.debug(f"Checking disk cache for kernel {get_prim_func_name(func, '<unknown>')}")
# Then check disk cache
kernel = self._load_kernel_from_disk(
......@@ -193,13 +197,13 @@ class KernelCache:
)
if kernel is not None:
if verbose:
self.logger.debug(f"Found kernel in disk cache for {func.attrs['global_symbol']}")
self.logger.debug(f"Found kernel in disk cache for {get_prim_func_name(func, '<unknown>')}")
# Populate memory cache with disk result
self._memory_cache[key] = kernel
return kernel
if verbose:
self.logger.debug(f"No cached kernel for {func.attrs['global_symbol']}")
self.logger.debug(f"No cached kernel for {get_prim_func_name(func, '<unknown>')}")
# Compile kernel if cache miss; leave critical section
kernel = JITKernel(
func,
......
......@@ -16,5 +16,6 @@ from .language import (
is_full_region, # noqa: F401
to_buffer_region, # noqa: F401
get_buffer_region_from_load, # noqa: F401
get_prim_func_name, # noqa: F401
)
from .deprecated import deprecated # noqa: F401
......@@ -478,3 +478,27 @@ def is_full_region(buffer_region: BufferRegion) -> bool:
if not expr_equal(r.extent, dim):
return False
return True
def get_prim_func_name(func: PrimFunc | None, default: str | None = None) -> str | None:
"""
Extract a human‑readable function name from a TVM PrimFunc.
Prefer the `global_symbol` attribute set on the PrimFunc. If it is missing
(e.g., private PrimFunc without a global symbol), return the provided
`default` value.
Args:
func: TVM PrimFunc instance or None.
default: Fallback name to return when no name can be determined.
Returns:
The function name as a string, or `default` when unavailable.
"""
if func is None:
return default
try:
name = func.attrs["global_symbol"]
return str(name) if name is not None else default
except Exception:
return default
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment