Commit b1e6b27f authored by pigKiller's avatar pigKiller Committed by LeiWang1999
Browse files

[AMD][Setup] Support HIP in setup.py (#369)



* add hip setup support

* add env.find_hip func

* Delete install_hip.sh as we already have install_rocm.sh

* modify hip to rocm

---------
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent c4638d65
...@@ -26,6 +26,8 @@ ROOT_DIR = os.path.dirname(__file__) ...@@ -26,6 +26,8 @@ ROOT_DIR = os.path.dirname(__file__)
# Add LLVM control environment variable # Add LLVM control environment variable
USE_LLVM = os.environ.get("USE_LLVM", "False").lower() == "true" USE_LLVM = os.environ.get("USE_LLVM", "False").lower() == "true"
# Add ROCM control environment variable
USE_ROCM = os.environ.get("USE_ROCM", "False").lower() == "true"
def load_module_from_path(module_name, path): def load_module_from_path(module_name, path):
...@@ -37,9 +39,20 @@ def load_module_from_path(module_name, path): ...@@ -37,9 +39,20 @@ def load_module_from_path(module_name, path):
envs = load_module_from_path('env', os.path.join(ROOT_DIR, PACKAGE_NAME, 'env.py')) envs = load_module_from_path('env', os.path.join(ROOT_DIR, PACKAGE_NAME, 'env.py'))
CUDA_HOME = envs.CUDA_HOME CUDA_HOME = envs.CUDA_HOME
ROCM_HOME = envs.ROCM_HOME
# Check if both CUDA and ROCM are enabled
if USE_ROCM and not ROCM_HOME:
raise ValueError("ROCM support is enabled (USE_ROCM=True) but ROCM_HOME is not set or detected.")
if not USE_ROCM and not CUDA_HOME:
raise ValueError("CUDA support is enabled by default (USE_ROCM=False) but CUDA_HOME is not set or detected.")
assert CUDA_HOME, "Failed to automatically detect CUDA installation. Please set the CUDA_HOME environment variable manually (e.g., export CUDA_HOME=/usr/local/cuda)." # Ensure one of CUDA or ROCM is available
if not (CUDA_HOME or ROCM_HOME):
raise ValueError("Failed to automatically detect CUDA or ROCM installation. Please set the CUDA_HOME or ROCM_HOME environment variable manually (e.g., export CUDA_HOME=/usr/local/cuda or export ROCM_HOME=/opt/rocm).")
# TileLang only supports Linux platform # TileLang only supports Linux platform
assert sys.platform.startswith("linux"), "TileLang only supports Linux platform (including WSL)." assert sys.platform.startswith("linux"), "TileLang only supports Linux platform (including WSL)."
...@@ -82,15 +95,45 @@ def get_nvcc_cuda_version(): ...@@ -82,15 +95,45 @@ def get_nvcc_cuda_version():
return nvcc_cuda_version return nvcc_cuda_version
def get_rocm_version():
"""Get the ROCM version from rocminfo."""
rocm_output = subprocess.check_output(["rocminfo"], universal_newlines=True)
# Parse ROCM version from output
# Example output: ROCM version: x.y.z-...
match = re.search(r'ROCm Version: (\d+\.\d+\.\d+)', rocm_output)
if match:
return LooseVersion(match.group(1))
else:
rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm")
rocm_version_file = os.path.join(rocm_path, "lib", "cmake", "rocm", "rocm-config-version.cmake")
if os.path.exists(rocm_version_file):
with open(rocm_version_file, "r") as f:
content = f.read()
match = re.search(r'set\(PACKAGE_VERSION "(\d+\.\d+\.\d+)"', content)
if match:
return LooseVersion(match.group(1))
# return a default
return LooseVersion("5.0.0")
def get_tilelang_version(with_cuda=True, with_system_info=True) -> str: def get_tilelang_version(with_cuda=True, with_system_info=True) -> str:
version = find_version(get_path(".", "VERSION")) version = find_version(get_path(".", "VERSION"))
local_version_parts = [] local_version_parts = []
if with_system_info: if with_system_info:
local_version_parts.append(get_system_info().replace("-", ".")) local_version_parts.append(get_system_info().replace("-", "."))
if with_cuda: if with_cuda:
if USE_ROCM:
if ROCM_HOME:
rocm_version = str(get_rocm_version())
rocm_version_str = rocm_version.replace(".", "")[:3]
local_version_parts.append(f"rocm{rocm_version_str}")
else:
if CUDA_HOME:
cuda_version = str(get_nvcc_cuda_version()) cuda_version = str(get_nvcc_cuda_version())
cuda_version_str = cuda_version.replace(".", "")[:3] cuda_version_str = cuda_version.replace(".", "")[:3]
local_version_parts.append(f"cu{cuda_version_str}") local_version_parts.append(f"cu{cuda_version_str}")
if local_version_parts: if local_version_parts:
version += f"+{'.'.join(local_version_parts)}" version += f"+{'.'.join(local_version_parts)}"
return version return version
...@@ -205,10 +248,15 @@ def build_csrc(llvm_config_path): ...@@ -205,10 +248,15 @@ def build_csrc(llvm_config_path):
# Copy the config.cmake as a baseline # Copy the config.cmake as a baseline
if not os.path.exists("config.cmake"): if not os.path.exists("config.cmake"):
shutil.copy("../3rdparty/tvm/cmake/config.cmake", "config.cmake") shutil.copy("../3rdparty/tvm/cmake/config.cmake", "config.cmake")
# Set LLVM path and enable CUDA in config.cmake # Set LLVM path and enable CUDA or ROCM in config.cmake
with open("config.cmake", "a") as config_file: with open("config.cmake", "a") as config_file:
config_file.write(f"set(USE_LLVM {llvm_config_path})\n") config_file.write(f"set(USE_LLVM {llvm_config_path})\n")
if USE_ROCM:
config_file.write(f"set(USE_ROCM {ROCM_HOME})\n")
config_file.write("set(USE_CUDA OFF)\n")
else:
config_file.write(f"set(USE_CUDA {CUDA_HOME})\n") config_file.write(f"set(USE_CUDA {CUDA_HOME})\n")
config_file.write("set(USE_ROCM OFF)\n")
# Run CMake and make # Run CMake and make
try: try:
subprocess.check_call(["cmake", ".."]) subprocess.check_call(["cmake", ".."])
...@@ -560,7 +608,12 @@ class CMakeBuild(build_ext): ...@@ -560,7 +608,12 @@ class CMakeBuild(build_ext):
# Append some configuration variables to 'config.cmake' # Append some configuration variables to 'config.cmake'
with open(dst_config_cmake, "a") as config_file: with open(dst_config_cmake, "a") as config_file:
config_file.write(f"set(USE_LLVM {llvm_config_path})\n") config_file.write(f"set(USE_LLVM {llvm_config_path})\n")
if USE_ROCM:
config_file.write(f"set(USE_ROCM {ROCM_HOME})\n")
config_file.write("set(USE_CUDA OFF)\n")
else:
config_file.write(f"set(USE_CUDA {CUDA_HOME})\n") config_file.write(f"set(USE_CUDA {CUDA_HOME})\n")
config_file.write("set(USE_ROCM OFF)\n")
# Run CMake to configure the project with the given arguments. # Run CMake to configure the project with the given arguments.
subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp) subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp)
...@@ -581,7 +634,7 @@ setup( ...@@ -581,7 +634,7 @@ setup(
long_description=read_readme(), long_description=read_readme(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
platforms=[ platforms=[
"Environment :: GPU :: NVIDIA CUDA", "Environment :: GPU :: NVIDIA CUDA" if not USE_ROCM else "Environment :: GPU :: AMD ROCm",
"Operating System :: POSIX :: Linux", "Operating System :: POSIX :: Linux",
], ],
license="MIT", license="MIT",
......
...@@ -32,7 +32,22 @@ def _find_cuda_home() -> str: ...@@ -32,7 +32,22 @@ def _find_cuda_home() -> str:
return cuda_home if cuda_home is not None else "" return cuda_home if cuda_home is not None else ""
def _find_rocm_home() -> str:
"""Find the ROCM install path."""
rocm_home = os.environ.get('ROCM_PATH') or os.environ.get('ROCM_HOME')
if rocm_home is None:
rocmcc_path = shutil.which("hipcc")
if rocmcc_path is not None:
rocm_home = os.path.dirname(os.path.dirname(rocmcc_path))
else:
rocm_home = '/opt/rocm'
if not os.path.exists(rocm_home):
rocm_home = None
return rocm_home if rocm_home is not None else ""
CUDA_HOME = _find_cuda_home() CUDA_HOME = _find_cuda_home()
ROCM_HOME = _find_rocm_home()
CUTLASS_INCLUDE_DIR: str = os.environ.get("TL_CUTLASS_PATH", None) CUTLASS_INCLUDE_DIR: str = os.environ.get("TL_CUTLASS_PATH", None)
COMPOSABLE_KERNEL_INCLUDE_DIR: str = os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) COMPOSABLE_KERNEL_INCLUDE_DIR: str = os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None)
...@@ -174,6 +189,7 @@ __all__ = [ ...@@ -174,6 +189,7 @@ __all__ = [
"TVM_LIBRARY_PATH", "TVM_LIBRARY_PATH",
"TILELANG_TEMPLATE_PATH", "TILELANG_TEMPLATE_PATH",
"CUDA_HOME", "CUDA_HOME",
"ROCM_HOME",
"TILELANG_CACHE_DIR", "TILELANG_CACHE_DIR",
"enable_cache", "enable_cache",
"disable_cache", "disable_cache",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment