Commit 901deae1 authored by Wenhao Xie, committed by GitHub

[Enhancement] Improve CUDA path detection (#157)

* [Typo] Fix formatting in installation instructions in README.md

* [Enhancement] Improve CUDA path detection and update configuration handling

* fix typo

* remove IS_WINDOWS constant

* lint fix

* Improve error messages for CUDA detection failure

* lint fix

* lint fix

* Fix .gitignore to correctly include venv directory
parent cfcbcf1e
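The thread running through the hunks below: the install script, setup.py, and the CUDA helpers in the Python package now resolve the toolkit location from a single value, tilelang.env.CUDA_HOME, instead of hard-coding /usr/local/cuda. A minimal sketch of how to inspect or override that value from a source checkout (the path below is hypothetical, and this assumes the package and its dependencies are importable):

import os

# Optional manual override; it must be set before tilelang is imported,
# because _find_cuda_home() runs once when tilelang/env.py is loaded.
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda-12.4")  # hypothetical path

from tilelang.env import CUDA_HOME  # exported via __all__ by this commit

print("TileLang resolved CUDA at:", CUDA_HOME or "<not found>")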
@@ -93,7 +93,11 @@ cd build
echo "Configuring TVM build with LLVM and CUDA paths..."
echo "set(USE_LLVM $LLVM_CONFIG_PATH)" >> config.cmake && echo "set(USE_CUDA /usr/local/cuda)" >> config.cmake
echo "set(USE_LLVM \"$LLVM_CONFIG_PATH\")" >> config.cmake && \
CUDA_HOME=$(python -c "import sys; sys.path.append('../tilelang'); from env import CUDA_HOME; print(CUDA_HOME)") || \
{ echo "ERROR: Failed to retrieve CUDA_HOME via Python script." >&2; exit 1; } && \
{ [ -n "$CUDA_HOME" ] || { echo "ERROR: CUDA_HOME is empty, check CUDA installation or _find_cuda_home() in setup.py" >&2; exit 1; }; } && \
echo "set(USE_CUDA \"$CUDA_HOME\")" >> config.cmake
echo "Running CMake for TileLang..."
cmake ..
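For clarity, the inline python -c call above amounts to the following script, run from the build directory (so that ../tilelang/env.py is reachable); the surrounding shell chain aborts when the printed value is empty:

import sys

# Make ../tilelang importable, then read the detected toolkit path from env.py.
sys.path.append('../tilelang')
from env import CUDA_HOME  # _find_cuda_home() returns "" when nothing is found

print(CUDA_HOME)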
@@ -20,12 +20,27 @@ from distutils.version import LooseVersion
import platform
import multiprocessing
from setuptools.command.build_ext import build_ext
import importlib
# Environment variables False/True
PYPI_BUILD = os.environ.get("PYPI_BUILD", "False").lower() == "true"
PACKAGE_NAME = "tilelang"
ROOT_DIR = os.path.dirname(__file__)
def load_module_from_path(module_name, path):
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module

envs = load_module_from_path('env', os.path.join(ROOT_DIR, PACKAGE_NAME, 'env.py'))
CUDA_HOME = envs.CUDA_HOME
assert CUDA_HOME, "Failed to automatically detect CUDA installation. Please set the CUDA_HOME environment variable manually (e.g., export CUDA_HOME=/usr/local/cuda)."
# TileLang only supports Linux platform
assert sys.platform.startswith("linux"), "TileLang only supports Linux platform (including WSL)."
@@ -193,7 +208,7 @@ def build_csrc(llvm_config_path):
    # Set LLVM path and enable CUDA in config.cmake
    with open("config.cmake", "a") as config_file:
        config_file.write(f"set(USE_LLVM {llvm_config_path})\n")
        config_file.write("set(USE_CUDA /usr/local/cuda)\n")
        config_file.write(f"set(USE_CUDA {CUDA_HOME})\n")

    # Run CMake and make
    try:
        subprocess.check_call(["cmake", ".."])
@@ -519,7 +534,7 @@ class CMakeBuild(build_ext):
        # Here, we set USE_LLVM and USE_CUDA, for example.
        with open(dst_config_cmake, "a") as config_file:
            config_file.write(f"set(USE_LLVM {llvm_config_path})\n")
            config_file.write("set(USE_CUDA /usr/local/cuda)\n")
            config_file.write(f"set(USE_CUDA {CUDA_HOME})\n")

        # Run CMake to configure the project with the given arguments.
        subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp)
@@ -8,6 +8,7 @@ from __future__ import absolute_import as _abs
import os
import subprocess
import warnings
from ..env import CUDA_HOME
import tvm._ffi
from tvm.target import Target
@@ -132,18 +133,11 @@ def find_cuda_path():
    path : str
        Path to cuda root.
    """
    if "CUDA_PATH" in os.environ:
        return os.environ["CUDA_PATH"]
    cmd = ["which", "nvcc"]
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    (out, _) = proc.communicate()
    out = py_str(out)
    if proc.returncode == 0:
        return os.path.realpath(os.path.join(str(out).strip(), "../.."))
    cuda_path = "/usr/local/cuda"
    if os.path.exists(os.path.join(cuda_path, "bin/nvcc")):
        return cuda_path
    raise RuntimeError("Cannot find cuda path")
    if CUDA_HOME:
        return CUDA_HOME
    raise RuntimeError(
        "Failed to automatically detect CUDA installation. Please set the CUDA_HOME environment variable manually (e.g., export CUDA_HOME=/usr/local/cuda)."
    )
def get_cuda_version(cuda_path=None):
@@ -2,9 +2,38 @@ import sys
import os
import pathlib
import logging
import shutil
import glob
logger = logging.getLogger(__name__)
def _find_cuda_home() -> str:
    """Find the CUDA install path.
    Adapted from https://github.com/pytorch/pytorch/blob/main/torch/utils/cpp_extension.py
    """
    # Guess #1
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        # Guess #2
        nvcc_path = shutil.which("nvcc")
        if nvcc_path is not None and "cuda" in nvcc_path.lower():
            cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
        else:
            # Guess #3
            if sys.platform == 'win32':
                cuda_homes = glob.glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
                cuda_home = '' if len(cuda_homes) == 0 else cuda_homes[0]
            else:
                cuda_home = '/usr/local/cuda'
            if not os.path.exists(cuda_home):
                cuda_home = None
    return cuda_home if cuda_home is not None else ""

CUDA_HOME = _find_cuda_home()
CUTLASS_INCLUDE_DIR: str = os.environ.get("TL_CUTLASS_PATH", None)
TVM_PYTHON_PATH: str = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)
TVM_LIBRARY_PATH: str = os.environ.get("TVM_LIBRARY_PATH", None)
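As a worked example of Guess #2 in _find_cuda_home() above (the nvcc location is hypothetical): stripping two path components from the nvcc binary found on PATH yields the toolkit root.

import os

nvcc_path = "/usr/local/cuda-12.4/bin/nvcc"  # hypothetical shutil.which("nvcc") result
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
assert cuda_home == "/usr/local/cuda-12.4"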
@@ -85,4 +114,5 @@ __all__ = [
"TVM_PYTHON_PATH",
"TVM_LIBRARY_PATH",
"TILELANG_TEMPLATE_PATH",
"CUDA_HOME",
]
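With CUDA_HOME now exported from tilelang.env, other modules can reuse a single detection result instead of probing the system again; the find_cuda_path() hunk above is one such consumer. A small usage sketch, again assuming the package is importable:

import os

from tilelang.env import CUDA_HOME

if CUDA_HOME:
    print("nvcc expected at:", os.path.join(CUDA_HOME, "bin", "nvcc"))
else:
    print("CUDA not detected; set CUDA_HOME, e.g. export CUDA_HOME=/usr/local/cuda")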