# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Installation script.""" import functools import glob import importlib import os import re import shutil import subprocess import sys import platform from pathlib import Path from importlib.metadata import version as get_version from subprocess import CalledProcessError from typing import List, Optional, Tuple, Union # Needs to stay consistent with .pre-commit-config.yaml config. def min_python_version() -> Tuple[int]: """Minimum supported Python version.""" return (3, 10, 0) def min_python_version_str() -> str: """String representing minimum supported Python version.""" return ".".join(map(str, min_python_version())) if sys.version_info < min_python_version(): raise RuntimeError( f"Transformer Engine requires Python {min_python_version_str()} or newer, " f"but found Python {platform.python_version()}." ) @functools.lru_cache(maxsize=None) def debug_build_enabled() -> bool: """Whether to build with a debug configuration""" return bool(int(os.getenv("NVTE_BUILD_DEBUG", "0"))) @functools.lru_cache(maxsize=None) def get_max_jobs_for_parallel_build() -> int: """Number of parallel jobs for Nina build""" # Default: maximum parallel jobs num_jobs = 0 # Check environment variable if os.getenv("NVTE_BUILD_MAX_JOBS"): num_jobs = int(os.getenv("NVTE_BUILD_MAX_JOBS")) elif os.getenv("MAX_JOBS"): num_jobs = int(os.getenv("MAX_JOBS")) # Check command-line arguments for arg in sys.argv.copy(): if arg.startswith("--parallel="): num_jobs = int(arg.replace("--parallel=", "")) sys.argv.remove(arg) return num_jobs def all_files_in_dir(path, name_extension=None): all_files = [] for dirname, _, names in os.walk(path): for name in names: if name_extension is not None and not name.endswith(f".{name_extension}"): continue all_files.append(Path(dirname, name)) return all_files def remove_dups(_list: List): return list(set(_list)) def found_cmake() -> bool: """ "Check if valid CMake is available CMake 3.18 or newer is required. """ # Check if CMake is available try: _cmake_bin = cmake_bin() except FileNotFoundError: return False # Query CMake for version info output = subprocess.run( [_cmake_bin, "--version"], capture_output=True, check=True, universal_newlines=True, ) match = re.search(r"version\s*([\d.]+)", output.stdout) version = match.group(1).split(".") version = tuple(int(v) for v in version) return version >= (3, 18) def cmake_bin() -> Path: """Get CMake executable Throws FileNotFoundError if not found. """ # Search in CMake Python package _cmake_bin: Optional[Path] = None try: from cmake import CMAKE_BIN_DIR except ImportError: pass else: _cmake_bin = Path(CMAKE_BIN_DIR).resolve() / "cmake" if not _cmake_bin.is_file(): _cmake_bin = None # Search in path if _cmake_bin is None: _cmake_bin = shutil.which("cmake") if _cmake_bin is not None: _cmake_bin = Path(_cmake_bin).resolve() # Return executable if found if _cmake_bin is None: raise FileNotFoundError("Could not find CMake executable") return _cmake_bin def found_ninja() -> bool: """ "Check if Ninja is available""" return shutil.which("ninja") is not None def found_pybind11() -> bool: """ "Check if pybind11 is available""" # Check if Python package is installed try: import pybind11 except ImportError: pass else: return True # Check if CMake can find pybind11 if not found_cmake(): return False try: subprocess.run( [ "cmake", "--find-package", "-DMODE=EXIST", "-DNAME=pybind11", "-DCOMPILER_ID=CXX", "-DLANGUAGE=CXX", ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True, ) except (CalledProcessError, OSError): pass else: return True return False @functools.lru_cache(maxsize=None) def rocm_build() -> bool: """ ROCm build should be performed if: - It is configured with NVTE_USE_ROCM=1 env OR: - HIP compiler is found and CUDA one is not """ if bool(int(os.getenv("NVTE_USE_ROCM", "0"))): return True try: nvcc_path() return False except FileNotFoundError: pass _, hipcc_bin = rocm_path() return hipcc_bin.is_file() @functools.lru_cache(maxsize=None) def rocm_path() -> Tuple[str, str]: """ROCm root path and HIPCC binary path as a tuple""" """If ROCm installation is not specified, use default /opt/dtk path""" if os.getenv("ROCM_PATH"): rocm_home = Path(os.getenv("ROCM_PATH")) hipcc_bin = rocm_home / "bin" / "hipcc" if hipcc_bin is None: hipcc_bin = shutil.which("hipcc") if hipcc_bin is not None: hipcc_bin = Path(hipcc_bin) rocm_home = hipcc_bin.parent.parent if hipcc_bin is None: rocm_home = Path("/opt/dtk/") hipcc_bin = rocm_home / "bin" / "hipcc" return rocm_home, hipcc_bin @functools.lru_cache(maxsize=None) def cuda_toolkit_include_path() -> Tuple[str, str]: """Returns root path for cuda toolkit includes. return `None` if CUDA is not found.""" # Try finding CUDA cuda_home: Optional[Path] = None if cuda_home is None and os.getenv("CUDA_HOME"): # Check in CUDA_HOME cuda_home = Path(os.getenv("CUDA_HOME")) / "include" if cuda_home is None: # Check in NVCC nvcc_bin = shutil.which("nvcc") if nvcc_bin is not None: cuda_home = Path(nvcc_bin.rstrip("/bin/nvcc")) / "include" if cuda_home is None: # Last-ditch guess in /usr/local/cuda if Path("/usr/local/cuda").is_dir(): cuda_home = Path("/usr/local/cuda") / "include" return cuda_home @functools.lru_cache(maxsize=None) def nvcc_path() -> Tuple[str, str]: """Returns the NVCC binary path. Throws FileNotFoundError if NVCC is not found.""" # Try finding NVCC nvcc_bin: Optional[Path] = None if nvcc_bin is None and os.getenv("CUDA_HOME"): # Check in CUDA_HOME cuda_home = Path(os.getenv("CUDA_HOME")) nvcc_bin = cuda_home / "bin" / "nvcc" if nvcc_bin is None: # Check if nvcc is in path nvcc_bin = shutil.which("nvcc") if nvcc_bin is not None: cuda_home = Path(nvcc_bin.rstrip("/bin/nvcc")) nvcc_bin = Path(nvcc_bin) if nvcc_bin is None: # Last-ditch guess in /usr/local/cuda cuda_home = Path("/usr/local/cuda") nvcc_bin = cuda_home / "bin" / "nvcc" if not nvcc_bin.is_file(): raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}") return nvcc_bin @functools.lru_cache(maxsize=None) def get_cuda_include_dirs() -> Tuple[str, str]: """Returns the CUDA header directory.""" # If cuda is installed via toolkit, all necessary headers # are bundled inside the top level cuda directory. if cuda_toolkit_include_path() is not None: return [cuda_toolkit_include_path()] # Use pip wheels to include all headers. try: import nvidia except ModuleNotFoundError as e: raise RuntimeError("CUDA not found.") cuda_root = Path(nvidia.__file__).parent return [ cuda_root / "cuda_nvcc" / "include", cuda_root / "cublas" / "include", cuda_root / "cuda_runtime" / "include", cuda_root / "cudnn" / "include", cuda_root / "cuda_cccl" / "include", cuda_root / "nvtx" / "include", cuda_root / "cuda_nvrtc" / "include", ] @functools.lru_cache(maxsize=None) def cuda_archs() -> str: archs = os.getenv("NVTE_CUDA_ARCHS") if archs is None: version = cuda_version() if version >= (13, 0): archs = "75;80;89;90;100;120" elif version >= (12, 8): archs = "70;80;89;90;100;120" else: archs = "70;80;89;90" return archs def cuda_version() -> Tuple[int, ...]: """CUDA Toolkit version as a (major, minor) tuple. Try to get cuda version by locating the nvcc executable and running nvcc --version. If nvcc is not found, look for the cuda runtime package pip `nvidia-cuda-runtime-cu12` and check pip version. """ try: nvcc_bin = nvcc_path() except FileNotFoundError as e: pass else: output = subprocess.run( [nvcc_bin, "-V"], capture_output=True, check=True, universal_newlines=True, ) match = re.search(r"release\s*([\d.]+)", output.stdout) version = match.group(1).split(".") return tuple(int(v) for v in version) try: version_str = get_version("nvidia-cuda-runtime-cu12") version_tuple = tuple(int(part) for part in version_str.split(".") if part.isdigit()) return version_tuple except importlib.metadata.PackageNotFoundError: raise RuntimeError("Could neither find NVCC executable nor CUDA runtime Python package.") def get_frameworks() -> List[str]: """DL frameworks to build support for""" _frameworks: List[str] = [] supported_frameworks = ["pytorch", "jax"] # Check environment variable if os.getenv("NVTE_FRAMEWORK"): _frameworks.extend(os.getenv("NVTE_FRAMEWORK").split(",")) # Check command-line arguments for arg in sys.argv.copy(): if arg.startswith("--framework="): _frameworks.extend(arg.replace("--framework=", "").split(",")) sys.argv.remove(arg) if rocm_build(): _requested_frameworks = [framework.lower() for framework in _frameworks] # Detect installed frameworks if not explicitly specified if not _frameworks: try: import torch except ImportError: pass else: _frameworks.append("pytorch") try: import jax except ImportError: pass else: _frameworks.append("jax") # Special framework names if "all" in _frameworks: _frameworks = supported_frameworks.copy() if "none" in _frameworks: _frameworks = [] # Check that frameworks are valid _frameworks = [framework.lower() for framework in _frameworks] for framework in _frameworks: if framework not in supported_frameworks: raise ValueError(f"Transformer Engine does not support framework={framework}") if rocm_build(): _unsupported_frameworks = [] if "pytorch" in _frameworks: try: from torch.utils.cpp_extension import IS_HIP_EXTENSION except ImportError: IS_HIP_EXTENSION=False if not IS_HIP_EXTENSION: if "pytorch" in _requested_frameworks: _unsupported_frameworks.append("pytorch") _frameworks.remove("pytorch") if "jax" in _frameworks: if not any(re.match(r'jax-rocm\d+-plugin', d.metadata['Name']) for d in importlib.metadata.distributions()): try: import jaxlib.rocm #pre JAX 0.4.30 way except ImportError: if "jax" in _requested_frameworks: _unsupported_frameworks.append("jax") _frameworks.remove("jax") if _unsupported_frameworks: raise ValueError(f"ROCm is not supported by requested frameworks: {_unsupported_frameworks}") return _frameworks def copy_common_headers( src_dir: Union[Path, str], dst_dir: Union[Path, str], ) -> None: """Copy headers from core library src_dir should be the transformer_engine directory within the root Transformer Engine repository. All .h and .cuh files within transformer_engine/common are copied into dst_dir. Relative paths are preserved. """ # Find common header files in src dir headers = glob.glob( os.path.join(str(src_dir), "common", "**", "*.h"), recursive=True, ) headers.extend( glob.glob( os.path.join(str(src_dir), "common", "**", "*.cuh"), recursive=True, ) ) headers = [Path(path) for path in headers] # Copy common header files to dst dir src_dir = Path(src_dir) dst_dir = Path(dst_dir) for path in headers: new_path = dst_dir / path.relative_to(src_dir) new_path.parent.mkdir(exist_ok=True, parents=True) shutil.copy(path, new_path) def hipify(base_dir, src_dir, sources, include_dirs): hipify_path = base_dir / "3rdparty" / "hipify_torch" cwd = os.getcwd() os.chdir(hipify_path) from hipify_torch.hipify_python import hipify as do_hipify os.chdir(cwd) hipify_result = do_hipify( project_directory=src_dir, output_directory=src_dir, includes=["*"], ignores=["*/amd_detail/*", "*/aotriton/*", "*/ck_fused_attn/*"], header_include_dirs=include_dirs, custom_map_list=base_dir / "hipify_custom_map.json", extra_files=[], is_pytorch_extension=True, hipify_extra_files_only=False, show_detailed=False) # Because hipify output_directory == project_directory # Original sources list may contain previous hipifying results that ends up with duplicated entries # Keep unique entries only hipified_sources = set() for fname in sources: fname = os.path.abspath(str(fname)) if fname in hipify_result: file_result = hipify_result[fname] if file_result.hipified_path is not None: fname = hipify_result[fname].hipified_path # setup() arguments must *always* be /-separated paths relative to the setup.py directory, # *never* absolute paths hipified_sources.add(os.path.relpath(fname, cwd)) return list(hipified_sources)