"host/vscode:/vscode.git/clone" did not exist on "80120f0a0c524d1efc0249926a73d5020f0efd67"
Commit 273be768 authored by Lei Wang, committed by LeiWang1999
Browse files

[Refactor] Skip patchelf if not installed (#477)

* [Refactor] Enhance TMA barrier validation and support for additional architectures

* Updated the TMA barrier validation in `inject_tma_barrier.cc` to check for non-empty `barrier_id_to_range_` before raising an error for missing `create_list_of_mbarrier`.
* Refactored architecture checks in `phase.py` to utilize a new constant `SUPPORTED_TMA_ARCHS`, allowing for easier updates and improved readability in the target architecture validation logic.

* Enhance logging in setup.py and refactor TMA architecture checks in phase.py

* Added logging configuration to setup.py, replacing print statements with logger for better traceability.
* Updated download and extraction functions to use logger for status messages.
* Refactored TMA architecture checks in phase.py to utilize the new `have_tma` function for improved clarity and maintainability.
* Introduced support for additional compute capabilities in nvcc.py, including TMA support checks.

* Update documentation for get_target_compute_version to reflect correct GPU compute capability range

* Refactor have_tma function to accept tvm.target.Target instead of compute_version

* Updated the `have_tma` function in nvcc.py to take a `target` parameter, improving clarity and usability.
* Adjusted calls to `have_tma` in phase.py to pass the target directly, enhancing maintainability and consistency in TMA support checks.
parent 8dec14e0
...@@ -18,6 +18,15 @@ import platform ...@@ -18,6 +18,15 @@ import platform
import multiprocessing import multiprocessing
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
import importlib import importlib
import logging
# Configure logging with basic settings
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)
# Environment variables False/True # Environment variables False/True
PYPI_BUILD = os.environ.get("PYPI_BUILD", "False").lower() == "true" PYPI_BUILD = os.environ.get("PYPI_BUILD", "False").lower() == "true"
...@@ -192,7 +201,7 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty" ...@@ -192,7 +201,7 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"
download_url = f"{base_url}/{file_name}" download_url = f"{base_url}/{file_name}"
# Download the file # Download the file
print(f"Downloading {file_name} from {download_url}") logger.info(f"Downloading {file_name} from {download_url}")
with urllib.request.urlopen(download_url) as response: with urllib.request.urlopen(download_url) as response:
if response.status != 200: if response.status != 200:
raise Exception(f"Download failed with status code {response.status}") raise Exception(f"Download failed with status code {response.status}")
...@@ -205,11 +214,11 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty" ...@@ -205,11 +214,11 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"
os.remove(os.path.join(extract_path, file_name)) os.remove(os.path.join(extract_path, file_name))
# Extract the file # Extract the file
print(f"Extracting {file_name} to {extract_path}") logger.info(f"Extracting {file_name} to {extract_path}")
with tarfile.open(fileobj=BytesIO(file_content), mode="r:xz") as tar: with tarfile.open(fileobj=BytesIO(file_content), mode="r:xz") as tar:
tar.extractall(path=extract_path) tar.extractall(path=extract_path)
print("Download and extraction completed successfully.") logger.info("Download and extraction completed successfully.")
return os.path.abspath(os.path.join(extract_path, file_name.replace(".tar.xz", ""))) return os.path.abspath(os.path.join(extract_path, file_name.replace(".tar.xz", "")))
...@@ -235,7 +244,7 @@ def update_submodules(): ...@@ -235,7 +244,7 @@ def update_submodules():
return False return False
if not is_git_repo(): if not is_git_repo():
print("Info: Not a git repository, skipping submodule update.") logger.info("Info: Not a git repository, skipping submodule update.")
return return
try: try:
...@@ -285,7 +294,15 @@ def patch_libs(libpath): ...@@ -285,7 +294,15 @@ def patch_libs(libpath):
and have a hard-coded rpath. and have a hard-coded rpath.
Set rpath to the directory of libs so auditwheel works well. Set rpath to the directory of libs so auditwheel works well.
""" """
subprocess.run(['patchelf', '--set-rpath', '$ORIGIN', libpath]) # check if patchelf is installed
# find patchelf in the system
patchelf_path = shutil.which("patchelf")
if not patchelf_path:
logger.warning(
"patchelf is not installed, which is required for auditwheel to work for compatible wheels."
)
return
subprocess.run([patchelf_path, '--set-rpath', '$ORIGIN', libpath])
class TileLangBuilPydCommand(build_py): class TileLangBuilPydCommand(build_py):
...@@ -299,11 +316,11 @@ class TileLangBuilPydCommand(build_py): ...@@ -299,11 +316,11 @@ class TileLangBuilPydCommand(build_py):
ext_modules = build_ext_cmd.extensions ext_modules = build_ext_cmd.extensions
for ext in ext_modules: for ext in ext_modules:
extdir = build_ext_cmd.get_ext_fullpath(ext.name) extdir = build_ext_cmd.get_ext_fullpath(ext.name)
print(f"Extension {ext.name} output directory: {extdir}") logger.info(f"Extension {ext.name} output directory: {extdir}")
ext_output_dir = os.path.dirname(extdir) ext_output_dir = os.path.dirname(extdir)
print(f"Extension output directory (parent): {ext_output_dir}") logger.info(f"Extension output directory (parent): {ext_output_dir}")
print(f"Build temp directory: {build_temp_dir}") logger.info(f"Build temp directory: {build_temp_dir}")
# copy cython files # copy cython files
CYTHON_SRC = [ CYTHON_SRC = [
...@@ -370,12 +387,12 @@ class TileLangBuilPydCommand(build_py): ...@@ -370,12 +387,12 @@ class TileLangBuilPydCommand(build_py):
os.makedirs(target_dir_release, exist_ok=True) os.makedirs(target_dir_release, exist_ok=True)
os.makedirs(target_dir_develop, exist_ok=True) os.makedirs(target_dir_develop, exist_ok=True)
shutil.copy2(source_lib_file, target_dir_release) shutil.copy2(source_lib_file, target_dir_release)
print(f"Copied {source_lib_file} to {target_dir_release}") logger.info(f"Copied {source_lib_file} to {target_dir_release}")
shutil.copy2(source_lib_file, target_dir_develop) shutil.copy2(source_lib_file, target_dir_develop)
print(f"Copied {source_lib_file} to {target_dir_develop}") logger.info(f"Copied {source_lib_file} to {target_dir_develop}")
os.remove(source_lib_file) os.remove(source_lib_file)
else: else:
print(f"WARNING: {item} not found in any expected directories!") logger.info(f"WARNING: {item} not found in any expected directories!")
TVM_CONFIG_ITEMS = [ TVM_CONFIG_ITEMS = [
f"{build_temp_dir}/config.cmake", f"{build_temp_dir}/config.cmake",
...@@ -391,7 +408,7 @@ class TileLangBuilPydCommand(build_py): ...@@ -391,7 +408,7 @@ class TileLangBuilPydCommand(build_py):
if os.path.exists(source_dir): if os.path.exists(source_dir):
shutil.copy2(source_dir, target_dir) shutil.copy2(source_dir, target_dir)
else: else:
print(f"INFO: {source_dir} does not exist.") logger.info(f"INFO: {source_dir} does not exist.")
TVM_PACAKGE_ITEMS = [ TVM_PACAKGE_ITEMS = [
"3rdparty/tvm/src", "3rdparty/tvm/src",
...@@ -486,7 +503,7 @@ class TileLangDevelopCommand(develop): ...@@ -486,7 +503,7 @@ class TileLangDevelopCommand(develop):
""" """
def run(self): def run(self):
print("Running TileLangDevelopCommand") logger.info("Running TileLangDevelopCommand")
# 1. Build the C/C++ extension modules # 1. Build the C/C++ extension modules
self.run_command("build_ext") self.run_command("build_ext")
...@@ -494,10 +511,10 @@ class TileLangDevelopCommand(develop): ...@@ -494,10 +511,10 @@ class TileLangDevelopCommand(develop):
ext_modules = build_ext_cmd.extensions ext_modules = build_ext_cmd.extensions
for ext in ext_modules: for ext in ext_modules:
extdir = build_ext_cmd.get_ext_fullpath(ext.name) extdir = build_ext_cmd.get_ext_fullpath(ext.name)
print(f"Extension {ext.name} output directory: {extdir}") logger.info(f"Extension {ext.name} output directory: {extdir}")
ext_output_dir = os.path.dirname(extdir) ext_output_dir = os.path.dirname(extdir)
print(f"Extension output directory (parent): {ext_output_dir}") logger.info(f"Extension output directory (parent): {ext_output_dir}")
# Copy the built TVM to the package directory # Copy the built TVM to the package directory
TVM_PREBUILD_ITEMS = [ TVM_PREBUILD_ITEMS = [
...@@ -521,7 +538,7 @@ class TileLangDevelopCommand(develop): ...@@ -521,7 +538,7 @@ class TileLangDevelopCommand(develop):
# remove the original file # remove the original file
os.remove(source_lib_file) os.remove(source_lib_file)
else: else:
print(f"INFO: {source_lib_file} does not exist.") logger.info(f"INFO: {source_lib_file} does not exist.")
class CMakeExtension(Extension): class CMakeExtension(Extension):
......
...@@ -268,7 +268,7 @@ def get_target_compute_version(target=None): ...@@ -268,7 +268,7 @@ def get_target_compute_version(target=None):
Returns Returns
------- -------
compute_version : str compute_version : str
compute capability of a GPU (e.g. "8.6") compute capability of a GPU (e.g. "8.6" or "9.0")
""" """
# 1. input target object # 1. input target object
# 2. Target.current() # 2. Target.current()
...@@ -277,10 +277,17 @@ def get_target_compute_version(target=None): ...@@ -277,10 +277,17 @@ def get_target_compute_version(target=None):
arch = target.arch.split("_")[1] arch = target.arch.split("_")[1]
if len(arch) == 2: if len(arch) == 2:
major, minor = arch major, minor = arch
# Handle old format like sm_89
return major + "." + minor return major + "." + minor
elif len(arch) == 3: elif len(arch) == 3:
# This is for arch like "sm_90a" major = int(arch[0])
major, minor, suffix = arch if major < 2:
major = arch[0:2]
minor = arch[2]
return major + "." + minor
else:
# This is for arch like "sm_90a"
major, minor, suffix = arch
return major + "." + minor + "." + suffix return major + "." + minor + "." + suffix
# 3. GPU compute version # 3. GPU compute version
...@@ -414,6 +421,23 @@ def have_fp8(compute_version): ...@@ -414,6 +421,23 @@ def have_fp8(compute_version):
return any(conditions) return any(conditions)
@tvm._ffi.register_func("tvm.contrib.nvcc.supports_tma", override=True)
def have_tma(target):
    """Whether TMA support is provided for the specified compilation target or not

    Parameters
    ----------
    target : tvm.target.Target
        The compilation target

    Returns
    -------
    have_tma : bool
        True if the major compute capability resolved for the target is 9 or higher
    """
    compute_version = get_target_compute_version(target)
    # Only the major version decides the result; the minor version is ignored.
    major, _ = parse_compute_version(compute_version)
    # TMA was introduced with Hopper (compute capability 9.0). Ada Lovelace is
    # compute capability 8.9 and therefore does not qualify.
    return major >= 9
def get_nvcc_compiler() -> str: def get_nvcc_compiler() -> str:
"""Get the path to the nvcc compiler""" """Get the path to the nvcc compiler"""
return os.path.join(find_cuda_path(), "bin", "nvcc") return os.path.join(find_cuda_path(), "bin", "nvcc")
...@@ -2,16 +2,15 @@ from tvm import tir, IRModule ...@@ -2,16 +2,15 @@ from tvm import tir, IRModule
from tvm.target import Target from tvm.target import Target
import tilelang import tilelang
from tilelang.transform import PassContext from tilelang.transform import PassContext
from tilelang.contrib.nvcc import have_tma
from typing import Optional from typing import Optional
SUPPORTED_TMA_ARCHS = {"sm_90", "sm_90a"}
def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None, def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
target: Optional[Target] = None) -> bool: target: Optional[Target] = None) -> bool:
if pass_ctx is None: if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context() pass_ctx = tilelang.transform.get_pass_context()
if target.arch not in SUPPORTED_TMA_ARCHS: if not have_tma(target):
return False return False
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False) disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False) disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
...@@ -20,7 +19,7 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None, ...@@ -20,7 +19,7 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
def allow_fence_proxy(target: Optional[Target] = None) -> bool: def allow_fence_proxy(target: Optional[Target] = None) -> bool:
return target.arch in SUPPORTED_TMA_ARCHS return have_tma(target)
def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool: def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment