# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Installation script for TE pytorch extensions."""

# pylint: disable=wrong-import-position,wrong-import-order

import sys
import os
import shutil
from pathlib import Path
import platform
import urllib

# `import urllib` alone does not load these submodules; import them explicitly
# because `urllib.request.urlretrieve` and `urllib.error.*` are used below.
import urllib.error
import urllib.request

import setuptools
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
from packaging.version import parse

try:
    import torch
    from torch.utils.cpp_extension import BuildExtension
except ImportError as e:
    raise RuntimeError("This package needs Torch to build.") from e


# Build switches read from the environment; each is enabled only by the exact string "TRUE".
FORCE_BUILD = os.getenv("NVTE_PYTORCH_FORCE_BUILD", "FALSE") == "TRUE"
FORCE_CXX11_ABI = os.getenv("NVTE_PYTORCH_FORCE_CXX11_ABI", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("NVTE_PYTORCH_SKIP_CUDA_BUILD", "FALSE") == "TRUE"

PACKAGE_NAME = "transformer_engine_torch"
BASE_WHEEL_URL = (
    "https://github.com/NVIDIA/TransformerEngine/releases/download/{tag_name}/{wheel_name}"
)

# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
if FORCE_CXX11_ABI:
    torch._C._GLIBCXX_USE_CXX11_ABI = True

current_file_path = Path(__file__).parent.resolve()
build_tools_dir = current_file_path.parent.parent / "build_tools"
# For release builds, or when building inside the repository, refresh a local copy of the
# shared `build_tools` package next to this script so the imports below can find it.
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_dir):
    build_tools_copy = current_file_path / "build_tools"
    if build_tools_copy.exists():
        shutil.rmtree(build_tools_copy)
    shutil.copytree(build_tools_dir, build_tools_copy)


from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers, min_python_version_str
from build_tools.te_version import te_version
from build_tools.pytorch import (
    setup_pytorch_extension,
    install_requirements,
    test_requirements,
)


os.environ["NVTE_PROJECT_BUILDING"] = "1"
CMakeBuildExtension = get_build_ext(BuildExtension, True)


def get_platform():
    """
    Returns the platform name as used in wheel filenames.

    Raises:
        ValueError: if ``sys.platform`` is not Linux, macOS or Windows.
    """
    if sys.platform.startswith("linux"):
        return f"linux_{platform.uname().machine}"
    if sys.platform == "darwin":
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        return f"macosx_{mac_version}_x86_64"
    if sys.platform == "win32":
        return "win_amd64"
    raise ValueError(f"Unsupported platform: {sys.platform}")


def get_wheel_url():
    """Construct the pre-built wheel URL for the current environment.

    The wheel name encodes the TE version, the CUDA major version torch was built
    with, torch's major.minor version, torch's C++11 ABI setting, the CPython tag
    and the platform tag.

    Returns:
        A ``(wheel_url, wheel_filename)`` tuple.
    """
    torch_version_raw = parse(torch.__version__)
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    nvte_version = te_version()
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # Use the CUDA version torch was built with, not the one currently installed.
    # Release wheels are compiled against a single minor version per CUDA major
    # (11.8 for CUDA 11, 12.3 for CUDA 12) and rely on minor-version compatibility,
    # so only the major version appears in the wheel name. Using torch's real major
    # (instead of mapping every non-11 major to 12) avoids picking a cu12 wheel for
    # a torch built against a different CUDA major.
    # NOTE(review): assumes torch.version.cuda is set (CUDA-enabled torch build).
    cuda_version = f"{parse(torch.version.cuda).major}"

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    wheel_filename = (
        f"{PACKAGE_NAME}-{nvte_version}"
        f"+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}"
        f"-{python_version}-{python_version}-{platform_name}.whl"
    )
    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{nvte_version}", wheel_name=wheel_filename)
    return wheel_url, wheel_filename


class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist_wheel command, which pip
    runs when it cannot find an existing wheel (i.e. whenever it builds this
    package from source).

    It uses the environment parameters to detect whether a pre-built, compatible
    wheel is already available for download and, if so, short-circuits the
    standard full build pipeline. Set NVTE_PYTORCH_FORCE_BUILD=TRUE to always
    build from source.
    """

    def run(self):
        """Place a downloaded pre-built wheel in ``dist_dir``, or build from source."""
        if FORCE_BUILD:
            # A forced source build must not be followed by the download path,
            # which would replace the freshly built wheel.
            super().run()
            return

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
        except urllib.error.URLError:
            # URLError also covers HTTPError (e.g. 404) and ContentTooShortError.
            print("Precompiled wheel not found. Building from source...")
            # A truncated download can be left behind; don't leave it in the tree.
            if os.path.exists(wheel_filename):
                os.remove(wheel_filename)
            # If the wheel could not be downloaded, build from source
            super().run()
            return

        # Make the archive
        # Lifted from the root wheel processing command
        # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
        os.makedirs(self.dist_dir, exist_ok=True)

        impl_tag, abi_tag, plat_tag = self.get_tag()
        archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

        wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
        print("Raw wheel path", wheel_path)
        # shutil.move (unlike os.rename) also works when dist_dir is on another filesystem.
        shutil.move(wheel_filename, wheel_path)


if __name__ == "__main__":
    # Extensions
    common_headers_dir = "common_headers"
    copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir))
    ext_modules = [
        setup_pytorch_extension(
            "csrc", current_file_path / "csrc", current_file_path / common_headers_dir
        )
    ]

    # Configure package
    setuptools.setup(
        name=PACKAGE_NAME,
        version=te_version(),
        description="Transformer acceleration library - Torch Lib",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand},
        python_requires=f">={min_python_version_str()}",
        install_requires=install_requirements(),
        tests_require=test_requirements(),
    )

    # Remove the temporary copies created next to this script. Paths are anchored to
    # the script's directory (where they were created) rather than the current working
    # directory, so running setup.py from elsewhere cannot target the wrong directories.
    if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
        shutil.rmtree(current_file_path / common_headers_dir)
        shutil.rmtree(current_file_path / "build_tools")