Commit 62a8d7f0 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Add commit ID to versioning and improve logging initialization (#524)

* Updated `get_tilelang_version` to include an optional commit ID in the version string.
* Enhanced the `TileLangBuilPydCommand` to write the version with commit ID to the VERSION file during the build process.
* Introduced a new function `get_git_commit_id` in `version.py` to retrieve the current git commit hash.
* Refactored logger initialization in `autotuner/__init__.py` to ensure handlers are set up only once, improving performance and clarity.
* Minor fixes in `flatten_buffer.cc` and `kernel_cache.py` for better handling of versioning and logging.
parent 41c51d07
...@@ -130,7 +130,7 @@ def get_rocm_version(): ...@@ -130,7 +130,7 @@ def get_rocm_version():
return LooseVersion("5.0.0") return LooseVersion("5.0.0")
def get_tilelang_version(with_cuda=True, with_system_info=True) -> str: def get_tilelang_version(with_cuda=True, with_system_info=True, with_commit_id=False) -> str:
version = find_version(get_path(".", "VERSION")) version = find_version(get_path(".", "VERSION"))
local_version_parts = [] local_version_parts = []
if with_system_info: if with_system_info:
...@@ -150,6 +150,18 @@ def get_tilelang_version(with_cuda=True, with_system_info=True) -> str: ...@@ -150,6 +150,18 @@ def get_tilelang_version(with_cuda=True, with_system_info=True) -> str:
if local_version_parts: if local_version_parts:
version += f"+{'.'.join(local_version_parts)}" version += f"+{'.'.join(local_version_parts)}"
if with_commit_id:
commit_id = None
try:
commit_id = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
stderr=subprocess.DEVNULL,
encoding='utf-8').strip()
except subprocess.SubprocessError as error:
raise RuntimeError("Failed to get git commit id") from error
if commit_id:
version += f"+{commit_id}"
return version return version
...@@ -473,6 +485,18 @@ class TileLangBuilPydCommand(build_py): ...@@ -473,6 +485,18 @@ class TileLangBuilPydCommand(build_py):
for item in TL_CONFIG_ITEMS: for item in TL_CONFIG_ITEMS:
source_dir = os.path.join(ROOT_DIR, item) source_dir = os.path.join(ROOT_DIR, item)
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item) target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
# if is VERSION file, replace the content with the new version with commit id
if not PYPI_BUILD and item == "VERSION":
version = get_tilelang_version(
with_cuda=False, with_system_info=False, with_commit_id=True)
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
with open(os.path.join(target_dir, item), "w") as f:
print(f"Writing {version} to {os.path.join(target_dir, item)}")
f.write(version)
continue
if os.path.isdir(source_dir): if os.path.isdir(source_dir):
self.mkpath(target_dir) self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir) distutils.dir_util.copy_tree(source_dir, target_dir)
...@@ -489,7 +513,7 @@ class TileLangSdistCommand(sdist): ...@@ -489,7 +513,7 @@ class TileLangSdistCommand(sdist):
def make_distribution(self): def make_distribution(self):
self.distribution.metadata.name = PACKAGE_NAME self.distribution.metadata.name = PACKAGE_NAME
self.distribution.metadata.version = get_tilelang_version( self.distribution.metadata.version = get_tilelang_version(
with_cuda=False, with_system_info=False) with_cuda=False, with_system_info=False, with_commit_id=False)
super().make_distribution() super().make_distribution()
...@@ -572,9 +596,10 @@ class CMakeBuild(build_ext): ...@@ -572,9 +596,10 @@ class CMakeBuild(build_ext):
# Check if CMake is installed and accessible by attempting to run 'cmake --version'. # Check if CMake is installed and accessible by attempting to run 'cmake --version'.
try: try:
subprocess.check_output(["cmake", "--version"]) subprocess.check_output(["cmake", "--version"])
except OSError as e: except OSError as error:
# If CMake is not found, raise an error. # If CMake is not found, raise an error.
raise RuntimeError("CMake must be installed to build the following extensions") from e raise RuntimeError(
"CMake must be installed to build the following extensions") from error
update_submodules() update_submodules()
......
...@@ -281,11 +281,12 @@ private: ...@@ -281,11 +281,12 @@ private:
auto int_bound = analyzer_->const_int_bound(index); auto int_bound = analyzer_->const_int_bound(index);
DataType dtype = index->dtype; DataType dtype = index->dtype;
if (dtype.is_int() && dtype.bits() < 64) { if (dtype.is_int() && dtype.bits() < 64) {
int64_t max_value = int_bound->max_value + 1; int64_t max_value = int_bound->max_value;
int64_t min_value = int_bound->min_value; int64_t min_value = int_bound->min_value;
const int64_t type_max = (1LL << (dtype.bits() - 1)); const int64_t type_max = (1LL << (dtype.bits() - 1));
const int64_t type_min = -(1LL << (dtype.bits() - 1)); const int64_t type_min = -(1LL << (dtype.bits() - 1));
if (max_value >= type_max || min_value < type_min) {
if (max_value >= (type_max - 1) || min_value < type_min) {
Int64Promoter promoter; Int64Promoter promoter;
for (auto &index : flattened_indices) { for (auto &index : flattened_indices) {
safe_indices.push_back(promoter(index)); safe_indices.push_back(promoter(index));
......
...@@ -44,18 +44,24 @@ logger = logging.getLogger(__name__) ...@@ -44,18 +44,24 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
logger.propagate = False logger.propagate = False
formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s') # Lazy handler initialization flag
_logger_handlers_initialized = False
file_handler = logging.FileHandler('autotuner.log', mode='w')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter) def _init_logger_handlers():
global _logger_handlers_initialized
console_handler = logging.StreamHandler(sys.stdout) if _logger_handlers_initialized:
console_handler.setLevel(logging.INFO) return
console_handler.setFormatter(formatter) formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
file_handler = logging.FileHandler('autotuner.log', mode='w')
logger.addHandler(file_handler) file_handler.setLevel(logging.DEBUG)
logger.addHandler(console_handler) file_handler.setFormatter(formatter)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
_logger_handlers_initialized = True
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -239,6 +245,7 @@ class AutoTuner: ...@@ -239,6 +245,7 @@ class AutoTuner:
Returns: Returns:
AutotuneResult: Results of the auto-tuning process. AutotuneResult: Results of the auto-tuning process.
""" """
_init_logger_handlers()
sig = inspect.signature(self.fn) sig = inspect.signature(self.fn)
keys = list(sig.parameters.keys()) keys = list(sig.parameters.keys())
bound_args = sig.bind() bound_args = sig.bind()
......
...@@ -15,6 +15,7 @@ import cloudpickle ...@@ -15,6 +15,7 @@ import cloudpickle
import logging import logging
from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled
from tilelang.version import __version__
KERNEL_PATH = "kernel.cu" KERNEL_PATH = "kernel.cu"
WRAPPED_KERNEL_PATH = "wrapped_kernel.cu" WRAPPED_KERNEL_PATH = "wrapped_kernel.cu"
...@@ -87,6 +88,7 @@ class KernelCache: ...@@ -87,6 +88,7 @@ class KernelCache:
""" """
func_binary = cloudpickle.dumps(func.script()) func_binary = cloudpickle.dumps(func.script())
key_data = { key_data = {
"version": __version__,
"func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key "func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key
"out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]), "out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]),
"args_repr": tuple( "args_repr": tuple(
...@@ -147,6 +149,8 @@ class KernelCache: ...@@ -147,6 +149,8 @@ class KernelCache:
with self._lock: with self._lock:
# First check in-memory cache # First check in-memory cache
if key in self._memory_cache: if key in self._memory_cache:
self.logger.warning("Found kernel in memory cache. For better performance," \
" consider using `@tilelang.jit` instead of direct kernel caching.")
return self._memory_cache[key] return self._memory_cache[key]
# Then check disk cache # Then check disk cache
......
import os import os
import subprocess
from typing import Union
# Get the absolute path of the current Python script's directory # Get the absolute path of the current Python script's directory
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
...@@ -22,5 +24,24 @@ else: ...@@ -22,5 +24,24 @@ else:
with open(version_file_path, "r") as version_file: with open(version_file_path, "r") as version_file:
__version__ = version_file.read().strip() __version__ = version_file.read().strip()
def get_git_commit_id() -> Union[str, None]:
"""Get the current git commit hash.
Returns:
str | None: The git commit hash if available, None otherwise.
"""
try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD'],
stderr=subprocess.DEVNULL,
encoding='utf-8').strip()
except subprocess.SubprocessError:
return None
# Append git commit hash to version if not already present
if "+" not in __version__ and (commit_id := get_git_commit_id()):
__version__ = f"{__version__}+{commit_id}"
# Define the public API for the module # Define the public API for the module
__all__ = ["__version__"] __all__ = ["__version__"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment