# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Installation script."""

from importlib import metadata
import os
import shlex
import shutil
import subprocess
import time
from pathlib import Path
from typing import List, Tuple

import setuptools
from wheel.bdist_wheel import bdist_wheel

from build_tools.build_ext import CMakeExtension, get_build_ext
from build_tools.te_version import te_version
from build_tools.utils import (
    cuda_archs,
    cuda_version,
    get_frameworks,
    remove_dups,
    min_python_version_str,
)


# Frameworks selected for this build (checked below for "pytorch" / "jax").
frameworks = get_frameworks()
# Directory containing this setup.py; used to locate sources and data files.
current_file_path = Path(__file__).parent.resolve()

# Default build_ext command; replaced below when a PyTorch or JAX build is requested.
from setuptools.command.build_ext import build_ext as BuildExtension

# NOTE(review): presumably read by build_tools / transformer_engine code to detect
# that it is being imported during a package build -- confirm against readers.
os.environ["NVTE_PROJECT_BUILDING"] = "1"

# Pick the framework-specific build_ext implementation (PyTorch takes precedence).
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    from pybind11.setup_helpers import build_ext as BuildExtension

# build_ext command class derived from the selected base; registered as the
# "build_ext" command in setuptools.setup() below.
CMakeBuildExtension = get_build_ext(BuildExtension)
# CUDA architectures passed to CMake as -DCMAKE_CUDA_ARCHITECTURES.
archs = cuda_archs()


class TimedBdist(bdist_wheel):
    """``bdist_wheel`` command that prints how long the wheel build took."""

    def run(self):
        began = time.perf_counter()
        super().run()
        elapsed = time.perf_counter() - began
        print(f"Total time for bdist_wheel: {elapsed:.2f} seconds")


def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library

    Collects CMake flags for the ``transformer_engine/common`` project: the CUDA
    architectures from the module-level ``archs`` plus optional features toggled
    by environment variables.

    Environment variables read:
        NVTE_UB_WITH_MPI: non-zero adds -DNVTE_UB_WITH_MPI=ON (MPI_HOME must be set).
        NVTE_ENABLE_NVSHMEM: non-zero adds -DNVTE_ENABLE_NVSHMEM=ON
            (NVSHMEM_HOME must be set).
        NVTE_BUILD_ACTIVATION_WITH_FAST_MATH: non-zero adds the matching -D flag.
        NVTE_WITH_CUBLASMP: non-zero adds -DNVTE_WITH_CUBLASMP=ON plus
            -DCUBLASMP_DIR / -DNVSHMEM_DIR, taken from CUBLASMP_HOME / NVSHMEM_HOME,
            or, when those are unset or empty, from the installed
            ``nvidia-cublasmp-cu<N>`` / ``nvidia-nvshmem-cu<N>`` distributions where
            N is ``cuda_version()[0]``.
        NVTE_CMAKE_EXTRA_ARGS: extra CMake arguments, split with shell-style quoting.

    Returns:
        CMakeExtension named "transformer_engine" whose CMake project is
        ``transformer_engine/common`` next to this file.

    Raises:
        RuntimeError: a feature requiring MPI_HOME / NVSHMEM_HOME is enabled
            without that variable being set.
        ValueError: a boolean NVTE_* variable is not an integer string, or
            NVTE_CMAKE_EXTRA_ARGS has unbalanced quotes.
        importlib.metadata.PackageNotFoundError: cuBLASMp is enabled, no *_HOME
            override is given, and the corresponding wheel is not installed.
    """
    cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]

    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
        # Explicit check instead of `assert` so it is not stripped under `python -O`.
        if os.getenv("MPI_HOME") is None:
            raise RuntimeError("MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1")
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
        if os.getenv("NVSHMEM_HOME") is None:
            raise RuntimeError(
                "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
            )
        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

    if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))):
        cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
        # `or` short-circuits: the installed wheel is only looked up (and
        # cuda_version() only called) when the *_HOME override is unset or empty.
        cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution(
            f"nvidia-cublasmp-cu{cuda_version()[0]}"
        ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
        cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
        nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
            f"nvidia-nvshmem-cu{cuda_version()[0]}"
        ).locate_file("nvidia/nvshmem")
        cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
        print("CMAKE_FLAGS:", cmake_flags[-2:])

    # Add custom CMake arguments from environment variable. shlex.split honours
    # quoting, so e.g. -DCMAKE_CUDA_FLAGS="-lineinfo -G" stays one argument
    # instead of being broken at the inner space as str.split() would do.
    nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
    if nvte_cmake_extra_args:
        cmake_flags.extend(shlex.split(nvte_cmake_extra_args))

    # Project directory root
    root_path = Path(__file__).resolve().parent

    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / Path("transformer_engine/common"),
        cmake_flags=cmake_flags,
    )


def setup_requirements() -> Tuple[List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for runtime and testing.

    Framework-specific requirements (from ``build_tools.pytorch`` and/or
    ``build_tools.jax``, depending on the module-level ``frameworks``) are added
    only when NVTE_RELEASE_BUILD is unset or "0".

    Returns:
        ``(install_reqs, test_reqs)``, each passed through ``remove_dups``.

    Raises:
        ValueError: NVTE_RELEASE_BUILD is set to a non-integer string.
    """

    # Common requirements
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Framework-specific requirements
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            from build_tools.pytorch import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())
        if "jax" in frameworks:
            from build_tools.jax import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())

    # A real tuple, matching the annotation (previously a 2-element list).
    return remove_dups(install_reqs), remove_dups(test_reqs)


def git_check_submodules() -> None:
    """
    Attempt to checkout git submodules automatically during setup.

    This runs successfully only if the submodules are
    either in the correct or uninitialized state.

    Note to devs: With this, any updates to the submodules itself, e.g. moving to a newer
    commit, must be committed before build. This also ensures that stale submodules aren't
    being silently used by developers.

    The check is skipped when NVTE_SKIP_SUBMODULE_CHECKS_DURING_BUILD is non-zero,
    when no ``git`` executable is on PATH, or when there is no ``.gitmodules`` next
    to this file. A failing git command (CalledProcessError) is ignored.

    Raises:
        RuntimeError: ``git submodule status`` reports a submodule whose status
            prefix is neither ``-`` (uninitialized) nor a space (at the recorded
            commit), e.g. ``+`` (a different commit is checked out).
    """

    # Provide an option to skip these checks for development.
    if bool(int(os.getenv("NVTE_SKIP_SUBMODULE_CHECKS_DURING_BUILD", "0"))):
        return

    # Require git executable.
    if shutil.which("git") is None:
        return

    # Require a .gitmodules file.
    if not (current_file_path / ".gitmodules").exists():
        return

    try:
        submodules = subprocess.check_output(
            ["git", "submodule", "status", "--recursive"],
            cwd=str(current_file_path),
            text=True,
        ).splitlines()

        for submodule in submodules:
            # Skip blank lines; indexing [0] on "" would raise IndexError.
            if not submodule:
                continue
            # '-' start is for an uninitialized submodule.
            # ' ' start is for a submodule on the correct commit.
            # Raised explicitly rather than via `assert`: under `python -O` an assert
            # is stripped and the update below would check out the recorded commits
            # over the developer's submodule changes.
            if submodule[0] not in (" ", "-"):
                raise RuntimeError(
                    "Submodules are initialized incorrectly. If this is intended, set the "
                    "environment variable `NVTE_SKIP_SUBMODULE_CHECKS_DURING_BUILD` to a "
                    "non-zero value to skip these checks during development. Otherwise, "
                    "run `git submodule update --init --recursive` to checkout the correct"
                    " submodule commits."
                )

        subprocess.check_call(
            ["git", "submodule", "update", "--init", "--recursive"],
            cwd=str(current_file_path),
        )
    except subprocess.CalledProcessError:
        return


if __name__ == "__main__":
    __version__ = te_version()
Przemek Tredak's avatar
Przemek Tredak committed
186

187
188
    git_check_submodules()

189
190
191
192
193
194
195
196
197
198
199
    with open("README.rst", encoding="utf-8") as f:
        long_description = f.read()

    # Settings for building top level empty package for dependency management.
    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        assert bool(
            int(os.getenv("NVTE_RELEASE_BUILD", "0"))
        ), "NVTE_RELEASE_BUILD env must be set for metapackage build."
        ext_modules = []
        package_data = {}
        include_package_data = False
200
        install_requires = []
201
        extras_require = {
202
203
204
            "core": [f"transformer_engine_cu12=={__version__}"],
            "core_cu12": [f"transformer_engine_cu12=={__version__}"],
            "core_cu13": [f"transformer_engine_cu13=={__version__}"],
205
206
207
208
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
        }
    else:
209
        install_requires, test_requires = setup_requirements()
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
        ext_modules = [setup_common_extension()]
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True
        extras_require = {"test": test_requires}

        if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
            if "pytorch" in frameworks:
                from build_tools.pytorch import setup_pytorch_extension

                ext_modules.append(
                    setup_pytorch_extension(
                        "transformer_engine/pytorch/csrc",
                        current_file_path / "transformer_engine" / "pytorch" / "csrc",
                        current_file_path / "transformer_engine",
                    )
225
                )
226
227
228
229
230
231
232
233
234
            if "jax" in frameworks:
                from build_tools.jax import setup_jax_extension

                ext_modules.append(
                    setup_jax_extension(
                        "transformer_engine/jax/csrc",
                        current_file_path / "transformer_engine" / "jax" / "csrc",
                        current_file_path / "transformer_engine",
                    )
235
                )
236

Tim Moon's avatar
Tim Moon committed
237
238
239
    # Configure package
    setuptools.setup(
        name="transformer_engine",
240
241
        version=__version__,
        packages=setuptools.find_packages(
242
243
244
245
246
            include=[
                "transformer_engine",
                "transformer_engine.*",
                "transformer_engine/build_tools",
            ],
247
        ),
248
        extras_require=extras_require,
Tim Moon's avatar
Tim Moon committed
249
        description="Transformer acceleration library",
250
251
        long_description=long_description,
        long_description_content_type="text/x-rst",
Tim Moon's avatar
Tim Moon committed
252
        ext_modules=ext_modules,
Phuong Nguyen's avatar
Phuong Nguyen committed
253
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
254
        python_requires=f">={min_python_version_str()}",
255
        classifiers=["Programming Language :: Python :: 3"],
Tim Moon's avatar
Tim Moon committed
256
257
        install_requires=install_requires,
        license_files=("LICENSE",),
258
259
        include_package_data=include_package_data,
        package_data=package_data,
Tim Moon's avatar
Tim Moon committed
260
    )