setup.py 7.35 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
"""Installation script."""

Przemek Tredak's avatar
Przemek Tredak committed
7
import os
Phuong Nguyen's avatar
Phuong Nguyen committed
8
import time
Tim Moon's avatar
Tim Moon committed
9
from pathlib import Path
10
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
11

Tim Moon's avatar
Tim Moon committed
12
import setuptools
Phuong Nguyen's avatar
Phuong Nguyen committed
13
from wheel.bdist_wheel import bdist_wheel
Przemek Tredak's avatar
Przemek Tredak committed
14

15
from build_tools.build_ext import CMakeExtension, get_build_ext
16
from build_tools.te_version import te_version
17
from build_tools.utils import (
18
    cuda_archs,
19
20
21
22
23
    found_cmake,
    found_ninja,
    found_pybind11,
    get_frameworks,
    install_and_import,
24
    remove_dups,
25
    cuda_toolkit_include_path,
26
)
Przemek Tredak's avatar
Przemek Tredak committed
27

28
29
# Frameworks whose extensions this build targets (as reported by build_tools).
frameworks = get_frameworks()
# Resolved directory that contains this setup script.
current_file_path = Path(__file__).parent.resolve()

# Flag the environment as "project building"; presumably read by build
# helpers / package code during the build — NOTE(review): confirm consumers.
os.environ["NVTE_PROJECT_BUILDING"] = "1"

# Choose the base ``build_ext`` command class: the framework-provided one when
# PyTorch or JAX is being built, otherwise plain setuptools.
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    install_and_import("pybind11[global]")
    from pybind11.setup_helpers import build_ext as BuildExtension
else:
    from setuptools.command.build_ext import build_ext as BuildExtension

# build_ext command derived from the chosen base via build_tools.get_build_ext
# (presumably adding CMakeExtension support, per its name).
CMakeBuildExtension = get_build_ext(BuildExtension)
# CUDA architecture string forwarded to CMake as CMAKE_CUDA_ARCHITECTURES.
archs = cuda_archs()
45

Tim Moon's avatar
Tim Moon committed
46

Phuong Nguyen's avatar
Phuong Nguyen committed
47
48
49
50
51
52
53
class TimedBdist(bdist_wheel):
    """``bdist_wheel`` command that reports the wall-clock time of the build."""

    def run(self):
        """Run the regular ``bdist_wheel`` command, then print how long it took."""
        began = time.perf_counter()
        super().run()
        elapsed = time.perf_counter() - began
        print(f"Total time for bdist_wheel: {elapsed:.2f} seconds")
Phuong Nguyen's avatar
Phuong Nguyen committed
55
56


57
def setup_common_extension() -> CMakeExtension:
    """Describe the CMake build of the ``transformer_engine`` common library.

    The CMake flags always include the module-level ``archs`` value as
    ``CMAKE_CUDA_ARCHITECTURES``. Optional features are switched on by these
    integer-valued environment variables (default ``0``):

    * ``NVTE_UB_WITH_MPI`` -> ``-DNVTE_UB_WITH_MPI=ON`` (requires ``MPI_HOME``)
    * ``NVTE_ENABLE_NVSHMEM`` -> ``-DNVTE_ENABLE_NVSHMEM=ON`` (requires ``NVSHMEM_HOME``)
    * ``NVTE_BUILD_ACTIVATION_WITH_FAST_MATH`` -> ``-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON``

    Returns:
        A ``CMakeExtension`` rooted at ``<project root>/transformer_engine/common``.

    Raises:
        AssertionError: if MPI or NVSHMEM is requested without its ``*_HOME``
            variable being set.
        ValueError: if one of the flag variables is not an integer string.
    """

    def _env_flag(var: str) -> bool:
        # Integer-valued env var, unset means 0; any non-zero value enables it.
        return bool(int(os.getenv(var, "0")))

    flags = [f"-DCMAKE_CUDA_ARCHITECTURES={archs}"]

    if _env_flag("NVTE_UB_WITH_MPI"):
        assert (
            os.getenv("MPI_HOME") is not None
        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
        flags.append("-DNVTE_UB_WITH_MPI=ON")

    if _env_flag("NVTE_ENABLE_NVSHMEM"):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

    if _env_flag("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH"):
        flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    # Resolve the script path first, then take its directory (project root).
    project_root = Path(__file__).resolve().parent

    return CMakeExtension(
        name="transformer_engine",
        cmake_path=project_root / "transformer_engine" / "common",
        cmake_flags=flags,
    )


def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for build, runtime, and testing.

    Returns:
        A ``(setup_reqs, install_reqs, test_reqs)`` tuple; each list is passed
        through ``remove_dups`` before being returned.

    Raises:
        ValueError: if ``NVTE_RELEASE_BUILD`` is set to a non-integer string.
    """

    # Common requirements
    setup_reqs: List[str] = []
    if cuda_toolkit_include_path() is None:
        # cuda_toolkit_include_path() returned None (presumably no local CUDA
        # toolkit was found), so build against the CUDA component wheels.
        setup_reqs.extend(
            [
                "nvidia-cuda-runtime-cu12",
                "nvidia-cublas-cu12",
                "nvidia-cudnn-cu12",
                "nvidia-cuda-cccl-cu12",
                "nvidia-cuda-nvcc-cu12",
                "nvidia-nvtx-cu12",
                "nvidia-cuda-nvrtc-cu12",
            ]
        )

    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Requirements that may be installed outside of Python
    if not found_cmake():
        setup_reqs.append("cmake>=3.21")
    if not found_ninja():
        setup_reqs.append("ninja")
    if not found_pybind11():
        setup_reqs.append("pybind11")

    # Framework-specific requirements (skipped for release builds, which only
    # package the common library).
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            setup_reqs.append("torch>=2.1")
            install_reqs.append("torch>=2.1")
            install_reqs.append(
                "nvdlfw-inspect @"
                " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
            )
            # Blackwell is not supported as of Triton 3.2.0, need custom internal build
            # install_reqs.append("triton")
            test_reqs.extend(["numpy", "torchvision", "transformers"])
        if "jax" in frameworks:
            setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
            install_reqs.extend(["jax", "flax>=0.7.1"])
            test_reqs.append("numpy")

    # Return a real 3-tuple, matching the annotated return type.
    return remove_dups(setup_reqs), remove_dups(install_reqs), remove_dups(test_reqs)
Tim Moon's avatar
Tim Moon committed
138

139

140
141
if __name__ == "__main__":
    __version__ = te_version()

    # Read relative to the current working directory.
    with open("README.rst", encoding="utf-8") as f:
        long_description = f.read()

    # Settings for building top level empty package for dependency management.
    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        assert bool(
            int(os.getenv("NVTE_RELEASE_BUILD", "0"))
        ), "NVTE_RELEASE_BUILD env must be set for metapackage build."
        ext_modules = []
        package_data = {}
        include_package_data = False
        setup_requires = []
        # A flat list of requirement strings (not a tuple wrapping a list).
        install_requires = [f"transformer_engine_cu12=={__version__}"]
        extras_require = {
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
        }
    else:
        setup_requires, install_requires, test_requires = setup_requirements()
        ext_modules = [setup_common_extension()]
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True
        extras_require = {"test": test_requires}

        # Framework extensions are only added for non-release builds.
        if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
            if "pytorch" in frameworks:
                from build_tools.pytorch import setup_pytorch_extension

                ext_modules.append(
                    setup_pytorch_extension(
                        "transformer_engine/pytorch/csrc",
                        current_file_path / "transformer_engine" / "pytorch" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )
            if "jax" in frameworks:
                from build_tools.jax import setup_jax_extension

                ext_modules.append(
                    setup_jax_extension(
                        "transformer_engine/jax/csrc",
                        current_file_path / "transformer_engine" / "jax" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        packages=setuptools.find_packages(
            # NOTE(review): find_packages matches dotted package names, so the
            # "/"-separated "transformer_engine/build_tools" pattern matches
            # nothing — confirm whether "build_tools" was intended.
            include=[
                "transformer_engine",
                "transformer_engine.*",
                "transformer_engine/build_tools",
            ],
        ),
        extras_require=extras_require,
        description="Transformer acceleration library",
        long_description=long_description,
        long_description_content_type="text/x-rst",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
        python_requires=">=3.8, <3.13",
        classifiers=[
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: 3.12",
        ],
        setup_requires=setup_requires,
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=include_package_data,
        package_data=package_data,
    )