# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Installation script."""

# Example ROCm build:
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=0 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v

import os
9
import sys
Phuong Nguyen's avatar
Phuong Nguyen committed
10
import time
Tim Moon's avatar
Tim Moon committed
11
from pathlib import Path
12
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
13

Tim Moon's avatar
Tim Moon committed
14
import setuptools
Phuong Nguyen's avatar
Phuong Nguyen committed
15
from wheel.bdist_wheel import bdist_wheel
Przemek Tredak's avatar
Przemek Tredak committed
16

17
from build_tools.build_ext import CMakeExtension, get_build_ext
18
from build_tools.te_version import te_version
19
from build_tools.utils import (
yuguo's avatar
yuguo committed
20
    rocm_build,
21
    cuda_archs,
22
23
24
25
26
    found_cmake,
    found_ninja,
    found_pybind11,
    get_frameworks,
    install_and_import,
27
    remove_dups,
28
    uninstall_te_wheel_packages,
29
)

frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()

# Default build_ext implementation from setuptools; it is replaced below when
# PyTorch or JAX is among the requested frameworks.
from setuptools.command.build_ext import build_ext as BuildExtension

# NOTE(review): presumably read by TE code to detect an in-progress source
# build — confirm against the consumers of NVTE_PROJECT_BUILDING.
os.environ["NVTE_PROJECT_BUILDING"] = "1"

if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    # pybind11 is installed on demand because its build_ext is imported next.
    install_and_import("pybind11[global]")
    from pybind11.setup_helpers import build_ext as BuildExtension

# Build command class produced by build_tools from the selected BuildExtension.
CMakeBuildExtension = get_build_ext(BuildExtension)

# CUDA architecture list; ROCm builds do not compute one.
archs = None if rocm_build() else cuda_archs()


class TimedBdist(bdist_wheel):
    """``bdist_wheel`` command that reports how long building the wheel took."""

    def run(self):
        """Run the regular ``bdist_wheel`` command, then print its wall-clock duration."""
        started = time.perf_counter()
        super().run()
        elapsed = time.perf_counter() - started
        print(f"Total time for bdist_wheel: {elapsed:.2f} seconds")


def _env_flag_enabled(name: str) -> bool:
    """Return whether environment variable ``name`` is set to a truthy value.

    An unset variable, or one whose value (case-insensitive, surrounding
    whitespace ignored) is ``""``, ``"0"``, ``"false"``, ``"off"`` or ``"no"``,
    counts as disabled. Any other value enables the flag.
    """
    value = os.getenv(name)
    if value is None:
        return False
    return value.strip().lower() not in ("", "0", "false", "off", "no")


def setup_common_extension() -> CMakeExtension:
    """Set up the CMake extension for the common ``transformer_engine`` library.

    CMake options are derived from environment variables:

    * ``NVTE_UB_WITH_MPI`` (integer string): adds ``-DNVTE_UB_WITH_MPI=ON``;
      requires ``MPI_HOME`` to be set.
    * ``NVTE_BUILD_ACTIVATION_WITH_FAST_MATH`` (integer string): adds
      ``-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON``.
    * ``NVTE_USE_HIPBLASLT`` / ``NVTE_USE_ROCBLAS`` (ROCm builds only): add
      ``-DUSE_HIPBLASLT=ON`` / ``-DUSE_ROCBLAS=ON`` when set to a truthy value
      (``0``, ``false``, ``off``, ``no`` or empty leave them off).

    Returns:
        A ``CMakeExtension`` named ``transformer_engine`` whose CMake project
        lives in ``transformer_engine/common`` next to this file.

    Raises:
        AssertionError: if ``NVTE_UB_WITH_MPI`` is enabled but ``MPI_HOME`` is unset.
        ValueError: if ``NVTE_UB_WITH_MPI`` or
            ``NVTE_BUILD_ACTIVATION_WITH_FAST_MATH`` is not an integer string.
    """
    if rocm_build():
        # ROCm builds do not pass a CUDA architecture list to CMake.
        cmake_flags = []
    else:
        cmake_flags = [f"-DCMAKE_CUDA_ARCHITECTURES={archs}"]

    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
        # Explicit check rather than `assert` so it is not stripped under `python -O`.
        if os.getenv("MPI_HOME") is None:
            raise AssertionError("MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1")
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

    if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    # Project directory root
    root_path = Path(__file__).resolve().parent

    if rocm_build():
        # Check the flag's value, not just its presence: NVTE_USE_ROCBLAS=0 must
        # leave rocBLAS disabled.
        if _env_flag_enabled("NVTE_USE_HIPBLASLT"):
            cmake_flags.append("-DUSE_HIPBLASLT=ON")
        if _env_flag_enabled("NVTE_USE_ROCBLAS"):
            cmake_flags.append("-DUSE_ROCBLAS=ON")

    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / Path("transformer_engine/common"),
        cmake_flags=cmake_flags,
    )


def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    """Set up Python dependencies.

    Framework-specific requirements are skipped when ``NVTE_RELEASE_BUILD`` is
    set to a non-zero integer.

    Returns:
        A ``(setup_requires, install_requires, test_requires)`` tuple. Each
        element is passed through ``remove_dups``.

    Raises:
        ValueError: if ``NVTE_RELEASE_BUILD`` is not an integer string.
    """

    # Common requirements
    setup_reqs: List[str] = []
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Requirements that may be installed outside of Python
    if not found_cmake():
        setup_reqs.append("cmake>=3.21")
    if not found_ninja():
        setup_reqs.append("ninja")
    if not found_pybind11():
        setup_reqs.append("pybind11")

    # Framework-specific requirements
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            install_reqs.append("torch>=2.1")
            # Triton is deliberately not listed: Blackwell is not supported as of
            # Triton 3.2.0, so a custom internal build is needed.
            test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
        if "jax" in frameworks:
            install_reqs.extend(["jax", "flax>=0.7.1"])
            # praxis is intentionally not part of the JAX test requirements.
            test_reqs.append("numpy")

    # Return a tuple to match the annotated return type; callers unpack 3 values.
    return tuple(remove_dups(reqs) for reqs in (setup_reqs, install_reqs, test_reqs))


if __name__ == "__main__":
    __version__ = te_version()

    # Resolve README.rst next to this file so the build does not depend on the CWD.
    with open(current_file_path / "README.rst", encoding="utf-8") as f:
        long_description = f.read()

    # Settings for building top level empty package for dependency management.
    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        # Explicit check rather than `assert` so it still runs under `python -O`.
        if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
            raise AssertionError("NVTE_RELEASE_BUILD env must be set for metapackage build.")
        ext_modules = []
        cmdclass = {}
        package_data = {}
        include_package_data = False
        setup_requires = []
        # A flat list of requirement strings, not a tuple wrapping a list.
        install_requires = [f"transformer_engine_cu12=={__version__}"]
        extras_require = {
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
        }
    else:
        setup_requires, install_requires, test_requires = setup_requirements()
        ext_modules = [setup_common_extension()]
        cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True
        extras_require = {"test": test_requires}

        if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
            # Remove residual FW packages since compiling from source
            # results in a single binary with FW extensions included.
            uninstall_te_wheel_packages()
            if "pytorch" in frameworks:
                from build_tools.pytorch import setup_pytorch_extension

                ext_modules.append(
                    setup_pytorch_extension(
                        "transformer_engine/pytorch/csrc",
                        current_file_path / "transformer_engine" / "pytorch" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )
            if "jax" in frameworks:
                from build_tools.jax import setup_jax_extension

                ext_modules.append(
                    setup_jax_extension(
                        "transformer_engine/jax/csrc",
                        current_file_path / "transformer_engine" / "jax" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        packages=setuptools.find_packages(
            # NOTE(review): find_packages matches dotted package names, so the
            # slash-separated "transformer_engine/build_tools" entry looks like a
            # no-op — confirm intent before removing it.
            include=[
                "transformer_engine",
                "transformer_engine.*",
                "transformer_engine/build_tools",
            ],
        ),
        extras_require=extras_require,
        description="Transformer acceleration library",
        long_description=long_description,
        long_description_content_type="text/x-rst",
        ext_modules=ext_modules,
        # Use the command classes chosen above (empty for the metapackage build).
        cmdclass=cmdclass,
        python_requires=">=3.8, <3.13",
        classifiers=[
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: 3.12",
        ],
        setup_requires=setup_requires,
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=include_package_data,
        package_data=package_data,
    )