setup.py 8.06 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
"""Installation script."""
yuguo's avatar
yuguo committed
6
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
yuguo's avatar
yuguo committed
7
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
8

Przemek Tredak's avatar
Przemek Tredak committed
9
import os
10
import sys
Phuong Nguyen's avatar
Phuong Nguyen committed
11
import time
Tim Moon's avatar
Tim Moon committed
12
from pathlib import Path
13
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
14

Tim Moon's avatar
Tim Moon committed
15
import setuptools
Phuong Nguyen's avatar
Phuong Nguyen committed
16
from wheel.bdist_wheel import bdist_wheel
Przemek Tredak's avatar
Przemek Tredak committed
17

18
from build_tools.build_ext import CMakeExtension, get_build_ext
19
from build_tools.te_version import te_version
20
from build_tools.utils import (
yuguo's avatar
yuguo committed
21
    rocm_build,
22
    cuda_archs,
23
24
25
26
27
    found_cmake,
    found_ninja,
    found_pybind11,
    get_frameworks,
    install_and_import,
28
    remove_dups,
29
    uninstall_te_wheel_packages,
30
)
Przemek Tredak's avatar
Przemek Tredak committed
31

32
33
# Frameworks selected for this build, and the directory holding this setup.py.
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()

# Fallback build_ext base class; swapped for a framework-specific one below
# when PyTorch or JAX is part of the build.
from setuptools.command.build_ext import build_ext as BuildExtension

# Exported for code that runs later in the build (presumably build_tools /
# the package itself — TODO confirm who reads it).
os.environ["NVTE_PROJECT_BUILDING"] = "1"

if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    # pybind11's setup helpers must be importable before we can subclass them.
    install_and_import("pybind11[global]")
    from pybind11.setup_helpers import build_ext as BuildExtension

CMakeBuildExtension = get_build_ext(BuildExtension)

# CUDA builds need an explicit architecture list; ROCm builds leave it unset.
archs = None if rocm_build() else cuda_archs()
52

Tim Moon's avatar
Tim Moon committed
53

Phuong Nguyen's avatar
Phuong Nguyen committed
54
55
56
57
58
59
60
class TimedBdist(bdist_wheel):
    """``bdist_wheel`` command that reports how long building the wheel took."""

    def run(self):
        """Run the stock ``bdist_wheel`` command, then print its wall-clock duration."""
        begin = time.perf_counter()
        super().run()
        elapsed = time.perf_counter() - begin
        print(f"Total time for bdist_wheel: {elapsed:.2f} seconds")
Phuong Nguyen's avatar
Phuong Nguyen committed
62
63


64
def setup_common_extension() -> CMakeExtension:
    """Describe the CMake build of the framework-independent common library.

    CMake flags are assembled from the target platform and ``NVTE_*``
    environment variables, in this order:

    1. ROCm: optional warning-suppression switches.
       CUDA: ``CMAKE_CUDA_ARCHITECTURES`` from the module-level ``archs``.
    2. ``NVTE_UB_WITH_MPI`` (requires ``MPI_HOME`` to be set).
    3. ``NVTE_BUILD_ACTIVATION_WITH_FAST_MATH``.
    4. ROCm only: ``USE_HIPBLASLT`` / ``USE_ROCBLAS`` when ``NVTE_USE_HIPBLASLT`` /
       ``NVTE_USE_ROCBLAS`` are present in the environment. Only presence is
       checked, so even a value of "0" enables them.

    Returns:
        A ``CMakeExtension`` named ``transformer_engine`` whose CMake project is
        ``transformer_engine/common`` next to this file.
    """

    def _enabled(var: str, default: str) -> bool:
        # Integer-valued toggle: "0" disables, any other integer enables;
        # a non-integer value raises ValueError.
        return bool(int(os.getenv(var, default)))

    if rocm_build():
        # (environment variable, default, CMake flag)
        warning_switches = [
            (
                "NVTE_BUILD_SUPPRESS_UNUSED_WARNING",
                "1",
                "-DNVTE_BUILD_SUPPRESS_UNUSED_WARNING=ON",
            ),
            (
                "NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING",
                "0",
                "-DNVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=ON",
            ),
            # NOTE(review): this env var lacks the "_WARNING" suffix that its
            # CMake option (and the two siblings above) carry — confirm intent.
            (
                "NVTE_BUILD_SUPPRESS_SIGN_COMPARE",
                "0",
                "-DNVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=ON",
            ),
        ]
        flags = [flag for var, default, flag in warning_switches if _enabled(var, default)]
    else:
        flags = [f"-DCMAKE_CUDA_ARCHITECTURES={archs}"]

    if _enabled("NVTE_UB_WITH_MPI", "0"):
        assert (
            os.getenv("MPI_HOME") is not None
        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
        flags.append("-DNVTE_UB_WITH_MPI=ON")

    if _enabled("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"):
        flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    project_root = Path(__file__).resolve().parent

    if rocm_build():
        blas_switches = [
            ("NVTE_USE_HIPBLASLT", "-DUSE_HIPBLASLT=ON"),
            ("NVTE_USE_ROCBLAS", "-DUSE_ROCBLAS=ON"),
        ]
        for var, flag in blas_switches:
            if os.getenv(var) is not None:
                flags.append(flag)

    return CMakeExtension(
        name="transformer_engine",
        cmake_path=project_root / "transformer_engine" / "common",
        cmake_flags=flags,
    )


def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    """Collect the Python dependency lists for this build.

    Build tools (cmake, ninja, pybind11) are requested from pip only when they
    cannot be found on the system. Framework packages are added only for
    non-release builds (``NVTE_RELEASE_BUILD`` unset or "0").

    Returns:
        ``[setup_requires, install_requires, test_requires]``, each passed
        through ``remove_dups``.
    """
    build_time: List[str] = []
    runtime: List[str] = ["pydantic", "importlib-metadata>=1.0", "packaging"]
    testing: List[str] = ["pytest>=8.2.1"]

    # Tools that may already be installed outside of Python: probe each in
    # turn and fall back to the pip package when it is missing.
    optional_tools = [
        (found_cmake, "cmake>=3.21"),
        (found_ninja, "ninja"),
        (found_pybind11, "pybind11"),
    ]
    for probe, package in optional_tools:
        if not probe():
            build_time.append(package)

    is_release_build = bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))
    if not is_release_build:
        if "pytorch" in frameworks:
            runtime.append("torch>=2.1")
            # Triton is deliberately not listed: Blackwell is not supported as
            # of Triton 3.2.0 and needs a custom internal build.
            testing += ["numpy", "torchvision", "prettytable", "PyYAML"]
        if "jax" in frameworks:
            runtime += ["jax", "flax>=0.7.1"]
            testing.append("numpy")

    return [remove_dups(group) for group in (build_time, runtime, testing)]
Tim Moon's avatar
Tim Moon committed
137

138

139
140
if __name__ == "__main__":
    __version__ = te_version()

    # Read the README next to this file rather than relative to the current
    # working directory, so the build works when invoked from elsewhere.
    with open(current_file_path / "README.rst", encoding="utf-8") as f:
        long_description = f.read()

    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        # Top-level empty package for dependency management: it compiles
        # nothing and only pins the framework-specific wheels.
        assert bool(
            int(os.getenv("NVTE_RELEASE_BUILD", "0"))
        ), "NVTE_RELEASE_BUILD env must be set for metapackage build."
        ext_modules = []
        # No extensions, hence no custom build_ext / timed bdist_wheel commands.
        cmdclass = {}
        package_data = {}
        include_package_data = False
        setup_requires = []
        # Must be a flat list of requirement strings, not a tuple wrapping a list.
        install_requires = [f"transformer_engine_cu12=={__version__}"]
        extras_require = {
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
        }
    else:
        setup_requires, install_requires, test_requires = setup_requirements()
        ext_modules = [setup_common_extension()]
        cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True
        extras_require = {"test": test_requires}

        if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
            # Remove residual FW packages since compiling from source
            # results in a single binary with FW extensions included.
            uninstall_te_wheel_packages()
            if "pytorch" in frameworks:
                from build_tools.pytorch import setup_pytorch_extension

                ext_modules.append(
                    setup_pytorch_extension(
                        "transformer_engine/pytorch/csrc",
                        current_file_path / "transformer_engine" / "pytorch" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )
            if "jax" in frameworks:
                from build_tools.jax import setup_jax_extension

                ext_modules.append(
                    setup_jax_extension(
                        "transformer_engine/jax/csrc",
                        current_file_path / "transformer_engine" / "jax" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        packages=setuptools.find_packages(
            include=[
                "transformer_engine",
                "transformer_engine.*",
                # NOTE(review): find_packages matches dotted package names, so
                # this slash-separated pattern likely matches nothing — confirm.
                "transformer_engine/build_tools",
            ],
        ),
        extras_require=extras_require,
        description="Transformer acceleration library",
        long_description=long_description,
        long_description_content_type="text/x-rst",
        ext_modules=ext_modules,
        # Use the per-mode command table chosen above (empty for the metapackage).
        cmdclass=cmdclass,
        python_requires=">=3.8, <3.13",
        classifiers=[
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: 3.12",
        ],
        setup_requires=setup_requires,
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=include_package_data,
        package_data=package_data,
    )