setup.py 7.22 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
"""Installation script."""

Przemek Tredak's avatar
Przemek Tredak committed
7
import os
8
import sys
Phuong Nguyen's avatar
Phuong Nguyen committed
9
import time
Tim Moon's avatar
Tim Moon committed
10
from pathlib import Path
11
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
12

Tim Moon's avatar
Tim Moon committed
13
import setuptools
Phuong Nguyen's avatar
Phuong Nguyen committed
14
from wheel.bdist_wheel import bdist_wheel
Przemek Tredak's avatar
Przemek Tredak committed
15

16
from build_tools.build_ext import CMakeExtension, get_build_ext
17
from build_tools.te_version import te_version
18
from build_tools.utils import (
19
    cuda_archs,
20
21
22
23
24
    found_cmake,
    found_ninja,
    found_pybind11,
    get_frameworks,
    install_and_import,
25
    remove_dups,
26
    uninstall_te_wheel_packages,
27
)
Przemek Tredak's avatar
Przemek Tredak committed
28

29
30
# Frameworks whose extensions will be compiled into this build.
frameworks = get_frameworks()
# Directory containing this setup script (the project root).
current_file_path = Path(__file__).parent.resolve()


# Fallback build_ext, used when neither PyTorch nor JAX is being built.
from setuptools.command.build_ext import build_ext as BuildExtension

# NOTE(review): presumably read by build_tools / framework modules to detect
# that a source build is in progress — confirm against consumers.
os.environ["NVTE_PROJECT_BUILDING"] = "1"

# Use the framework's own build_ext when a framework extension is built.
# PyTorch takes precedence if both frameworks are requested.
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    install_and_import("pybind11[global]")
    from pybind11.setup_helpers import build_ext as BuildExtension

# Wrap the selected build_ext via build_tools.build_ext.get_build_ext
# (presumably so CMakeExtension modules can be built — see build_tools).
CMakeBuildExtension = get_build_ext(BuildExtension)
# CUDA architectures, passed to CMake as CMAKE_CUDA_ARCHITECTURES by
# setup_common_extension.
archs = cuda_archs()


class TimedBdist(bdist_wheel):
    """Helper class to measure build time

    ``bdist_wheel`` command that measures the wall-clock duration of the
    parent ``run`` with :func:`time.perf_counter` and prints it.
    """

    def run(self):
        """Run ``bdist_wheel`` and print the elapsed time in seconds."""
        start_time = time.perf_counter()
        super().run()
        total_time = time.perf_counter() - start_time
        # Only reached when the wheel build did not raise.
        print(f"Total time for bdist_wheel: {total_time:.2f} seconds")


def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library

    Builds the ``transformer_engine`` CMake extension from
    ``transformer_engine/common``. It passes the module-level ``archs`` as
    ``CMAKE_CUDA_ARCHITECTURES`` and enables optional features from these
    environment variables:

    * ``NVTE_UB_WITH_MPI=1`` -> ``-DNVTE_UB_WITH_MPI=ON`` (requires ``MPI_HOME``)
    * ``NVTE_ENABLE_NVSHMEM=1`` -> ``-DNVTE_ENABLE_NVSHMEM=ON``
      (requires ``NVSHMEM_HOME``)
    * ``NVTE_BUILD_ACTIVATION_WITH_FAST_MATH=1`` ->
      ``-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON``

    Returns:
        The configured ``CMakeExtension``.

    Raises:
        AssertionError: if MPI or NVSHMEM support is requested but the
            matching ``*_HOME`` variable is unset.
        ValueError: if one of the flags above is set to a non-integer string.
    """
    cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]

    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
        assert (
            os.getenv("MPI_HOME") is not None
        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

    if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    # Project directory root
    root_path = Path(__file__).resolve().parent

    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / Path("transformer_engine/common"),
        cmake_flags=cmake_flags,
    )


def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for build, runtime, and testing.

    ``cmake``, ``ninja`` and ``pybind11`` are added as build requirements
    only when ``found_cmake`` / ``found_ninja`` / ``found_pybind11`` report
    them missing. Framework-specific runtime and test requirements are added
    for each entry in ``frameworks``, unless ``NVTE_RELEASE_BUILD`` is set to
    a non-zero integer.

    Returns:
        A tuple ``(setup_reqs, install_reqs, test_reqs)``, each passed
        through ``remove_dups``.
    """

    # Common requirements
    setup_reqs: List[str] = []
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Requirements that may be installed outside of Python
    if not found_cmake():
        setup_reqs.append("cmake>=3.21")
    if not found_ninja():
        setup_reqs.append("ninja")
    if not found_pybind11():
        setup_reqs.append("pybind11")

    # Framework-specific requirements
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            install_reqs.extend(["torch>=2.1"])
            install_reqs.append(
                "nvdlfw-inspect @"
                " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
            )
            # NOTE: "triton" is intentionally not listed. Blackwell is not
            # supported as of Triton 3.2.0, so a custom internal build is needed.
            test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
        if "jax" in frameworks:
            install_reqs.extend(["jax", "flax>=0.7.1"])
            test_reqs.extend(["numpy"])

    # Return a tuple to match the annotated return type; callers unpack it.
    return tuple(remove_dups(reqs) for reqs in (setup_reqs, install_reqs, test_reqs))


if __name__ == "__main__":
    __version__ = te_version()

    # Read the README next to this script, so building does not depend on
    # the current working directory.
    long_description = (current_file_path / "README.rst").read_text(encoding="utf-8")

    # Settings for building top level empty package for dependency management.
    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        assert bool(
            int(os.getenv("NVTE_RELEASE_BUILD", "0"))
        ), "NVTE_RELEASE_BUILD env must be set for metapackage build."
        # The metapackage compiles nothing and installs no custom commands.
        # It only pins the core and per-framework wheels to this exact version.
        ext_modules = []
        cmdclass = {}
        package_data = {}
        include_package_data = False
        setup_requires = []
        # install_requires must be a flat list of requirement strings.
        install_requires = [f"transformer_engine_cu12=={__version__}"]
        extras_require = {
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
        }
    else:
        setup_requires, install_requires, test_requires = setup_requirements()
        ext_modules = [setup_common_extension()]
        cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True
        extras_require = {"test": test_requires}

        if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
            # Remove residual FW packages since compiling from source
            # results in a single binary with FW extensions included.
            uninstall_te_wheel_packages()
            if "pytorch" in frameworks:
                from build_tools.pytorch import setup_pytorch_extension

                ext_modules.append(
                    setup_pytorch_extension(
                        "transformer_engine/pytorch/csrc",
                        current_file_path / "transformer_engine" / "pytorch" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )
            if "jax" in frameworks:
                from build_tools.jax import setup_jax_extension

                ext_modules.append(
                    setup_jax_extension(
                        "transformer_engine/jax/csrc",
                        current_file_path / "transformer_engine" / "jax" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        packages=setuptools.find_packages(
            include=[
                "transformer_engine",
                "transformer_engine.*",
                # NOTE(review): find_packages matches dotted package names, so
                # this path-style pattern likely matches nothing — confirm intent.
                "transformer_engine/build_tools",
            ],
        ),
        extras_require=extras_require,
        description="Transformer acceleration library",
        long_description=long_description,
        long_description_content_type="text/x-rst",
        ext_modules=ext_modules,
        # Use the per-mode command classes chosen above. The metapackage
        # build must not install the CMake build_ext / timed bdist_wheel.
        cmdclass=cmdclass,
        python_requires=">=3.8, <3.13",
        classifiers=[
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: 3.12",
        ],
        setup_requires=setup_requires,
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=include_package_data,
        package_data=package_data,
    )