# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Installation script for TE pytorch extensions."""

# pylint: disable=wrong-import-position,wrong-import-order

import sys
import os
import shutil
from pathlib import Path
13
14
import platform
import urllib
15
import setuptools
16
17
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
from packaging.version import parse
18
19

# Torch is a hard build-time requirement: it supplies the BuildExtension base
# class used for the extension build. Fail early with a clear message if it
# cannot be imported.
try:
    import torch
    from torch.utils.cpp_extension import BuildExtension
except ImportError as e:
    raise RuntimeError("This package needs Torch to build.") from e

25
26
27
28
29
30
31
32
33
34
35
36
# Distribution name and the URL template used to locate prebuilt release wheels.
PACKAGE_NAME = "transformer_engine_torch"
BASE_WHEEL_URL = (
    "https://github.com/NVIDIA/TransformerEngine/releases/download/{tag_name}/{wheel_name}"
)

# Build switches read from the environment. Each one is enabled only by the
# exact string "TRUE"; any other value (or an unset variable) leaves it off.
FORCE_BUILD = os.environ.get("NVTE_PYTORCH_FORCE_BUILD", "FALSE") == "TRUE"
FORCE_CXX11_ABI = os.environ.get("NVTE_PYTORCH_FORCE_CXX11_ABI", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.environ.get("NVTE_PYTORCH_SKIP_CUDA_BUILD", "FALSE") == "TRUE"

# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
if FORCE_CXX11_ABI:
    torch._C._GLIBCXX_USE_CXX11_ABI = True
37
38
39
40

# Directory that contains this setup script (symlinks in the directory path resolved).
current_file_path = Path(__file__).parent.resolve()
# Location of the shared build helpers, two levels above this script.
build_tools_dir = current_file_path.parent.parent / "build_tools"
# Stage a fresh copy of the build helpers next to this script when doing a
# release build (NVTE_RELEASE_BUILD set to a non-zero integer) or whenever the
# shared directory is present. Any stale copy is removed first so the staged
# helpers always mirror the source directory.
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_dir):
    build_tools_copy = current_file_path / "build_tools"
    if build_tools_copy.exists():
        shutil.rmtree(build_tools_copy)
    shutil.copytree(build_tools_dir, build_tools_copy)
45
46
47


from build_tools.build_ext import get_build_ext
48
from build_tools.utils import copy_common_headers, min_python_version_str
49
from build_tools.te_version import te_version
50
51
52
53
54
from build_tools.pytorch import (
    setup_pytorch_extension,
    install_requirements,
    test_requirements,
)
55
56


57
# Export a flag in this process's environment marking that a build is running.
# NOTE(review): the consumers of this variable are outside this file — confirm
# which modules read it.
os.environ["NVTE_PROJECT_BUILDING"] = "1"
58
# Command class registered for ``build_ext`` in the setup() call below; produced
# by the project helper from torch's BuildExtension.
# NOTE(review): the meaning of the second positional argument (True) is defined
# in build_tools.build_ext — confirm there.
CMakeBuildExtension = get_build_ext(BuildExtension, True)
59
60


61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def get_platform():
    """
    Build the platform tag used at the end of a wheel filename.

    Returns ``linux_<machine>`` on Linux, ``macosx_<major>.<minor>_x86_64`` on
    macOS and ``win_amd64`` on Windows.

    Raises ValueError for any other value of ``sys.platform``.
    """
    host = sys.platform

    if host.startswith("linux"):
        machine = platform.uname().machine
        return f"linux_{machine}"

    if host == "darwin":
        # Keep only the first two components of the macOS release string.
        release_parts = platform.mac_ver()[0].split(".")
        mac_version = ".".join(release_parts[:2])
        # NOTE(review): the architecture suffix is fixed to x86_64 regardless
        # of the host CPU.
        return f"macosx_{mac_version}_x86_64"

    if host == "win32":
        return "win_amd64"

    raise ValueError(f"Unsupported platform: {sys.platform}")


def get_wheel_url():
    """Construct the wheel URL for the current platform.

    The wheel filename encodes the TE version, the major CUDA version torch was
    built with, torch's ``major.minor`` version, the C++11 ABI flag, the CPython
    tag and the platform tag.

    Returns:
        A ``(wheel_url, wheel_filename)`` tuple of strings.
    """
    torch_version_raw = parse(torch.__version__)
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    nvte_version = te_version()
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # Use the CUDA version torch was built with, not the locally installed
    # toolkit. Wheels are tagged by CUDA major version only, so take the major
    # version directly rather than collapsing every non-11 release to CUDA 12
    # (which would pick a cu12 wheel for a CUDA 13 torch build).
    cuda_version = f"{parse(torch.version.cuda).major}"

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    wheel_filename = (
        f"{PACKAGE_NAME}-{nvte_version}"
        f"+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}"
        f"-{python_version}-{python_version}-{platform_name}.whl"
    )

    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{nvte_version}", wheel_name=wheel_filename)

    return wheel_url, wheel_filename


class CachedWheelsCommand(_bdist_wheel):
    """
    ``bdist_wheel`` command that prefers a prebuilt wheel over a source build.

    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it
    cannot find an existing wheel. We use the environment parameters to detect whether there
    is already a pre-built version of a compatible wheel available and, if so, short-circuit
    the standard full build pipeline.
    """

    def run(self):
        """Download a matching prebuilt wheel into ``dist_dir``, or build from source.

        A source build is performed when ``FORCE_BUILD`` is set or when the
        guessed wheel URL cannot be retrieved.
        """
        # A bare ``import urllib`` does not guarantee that the ``request`` and
        # ``error`` submodules are loaded, so import them explicitly here.
        # pylint: disable=import-outside-toplevel
        import urllib.error
        import urllib.request

        if FORCE_BUILD:
            # A forced source build must not be followed by a download attempt,
            # which could overwrite the wheel that was just built.
            super().run()
            return

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
        except (urllib.error.HTTPError, urllib.error.URLError):
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()
            return

        # Make the archive
        # Lifted from the root wheel processing command
        # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
        os.makedirs(self.dist_dir, exist_ok=True)

        impl_tag, abi_tag, plat_tag = self.get_tag()
        archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

        wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
        print("Raw wheel path", wheel_path)
        # shutil.move also works when dist_dir is on a different filesystem
        # than the download location, where os.rename would fail.
        shutil.move(wheel_filename, wheel_path)


138
139
140
if __name__ == "__main__":
    # Extensions
    # Stage the common headers next to this script so the extension build can
    # reference them by a path under current_file_path.
    common_headers_dir = "common_headers"
    copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir))
    ext_modules = [
        setup_pytorch_extension(
            "csrc", current_file_path / "csrc", current_file_path / common_headers_dir
        )
    ]

    # Setup version and requirements.
    # Having the framework extension depend on the core lib allows
    # us to detect CUDA version dynamically during compilation and
    # choose the correct wheel for te core lib.
    __version__ = te_version()
    # Validate with explicit raises rather than assert, which is stripped under -O.
    if torch.version.cuda is None:
        raise RuntimeError("Unsupported torch build: torch.version.cuda is None.")
    cuda_major_version = parse(torch.version.cuda).major
    if cuda_major_version not in (12, 13):
        raise RuntimeError(f"Unsupported cuda version {torch.version.cuda}.")
    te_core = f"transformer_engine_cu{cuda_major_version}=={__version__}"
    install_requires = install_requirements() + [te_core]

    # Configure package
    setuptools.setup(
        name=PACKAGE_NAME,
        version=__version__,
        description="Transformer acceleration library - Torch Lib",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand},
        python_requires=f">={min_python_version_str()}",
        install_requires=install_requires,
        tests_require=test_requirements(),
    )
    # Remove the staged directories after packaging-type invocations. Use the
    # same absolute locations they were created at, so cleanup does not depend
    # on the current working directory.
    if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
        shutil.rmtree(current_file_path / common_headers_dir)
        shutil.rmtree(current_file_path / "build_tools")