setup.py 8.43 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
"""Installation script."""
yuguo's avatar
yuguo committed
6
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
yuguo's avatar
yuguo committed
7
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
8

Przemek Tredak's avatar
Przemek Tredak committed
9
import os
Phuong Nguyen's avatar
Phuong Nguyen committed
10
import time
Tim Moon's avatar
Tim Moon committed
11
from pathlib import Path
12
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
13

Tim Moon's avatar
Tim Moon committed
14
import setuptools
Phuong Nguyen's avatar
Phuong Nguyen committed
15
from wheel.bdist_wheel import bdist_wheel
Przemek Tredak's avatar
Przemek Tredak committed
16

17
from build_tools.build_ext import CMakeExtension, get_build_ext
18
from build_tools.te_version import te_version
19
from build_tools.utils import (
yuguo's avatar
yuguo committed
20
    rocm_build,
21
    cuda_archs,
22
23
24
25
26
    found_cmake,
    found_ninja,
    found_pybind11,
    get_frameworks,
    install_and_import,
27
    remove_dups,
28
)
Przemek Tredak's avatar
Przemek Tredak committed
29

30
31
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()

# Default extension builder; replaced below by a framework-specific one when
# PyTorch or JAX is among the selected frameworks.
from setuptools.command.build_ext import build_ext as BuildExtension

# Set before any framework import below.
# NOTE(review): presumably read by build/runtime code to detect that it is
# running inside this source build — confirm against readers of this variable.
os.environ["NVTE_PROJECT_BUILDING"] = "1"

if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    install_and_import("pybind11[global]")
    from pybind11.setup_helpers import build_ext as BuildExtension

# Wrap the chosen builder so it can also drive CMakeExtension targets.
CMakeBuildExtension = get_build_ext(BuildExtension)

# CUDA architectures only apply to CUDA builds; ROCm builds leave this unset.
archs = None if rocm_build() else cuda_archs()
50

Tim Moon's avatar
Tim Moon committed
51

Phuong Nguyen's avatar
Phuong Nguyen committed
52
53
54
55
56
57
58
class TimedBdist(bdist_wheel):
    """Helper class to measure build time

    Wraps ``bdist_wheel.run`` and prints the wall-clock duration of the wheel
    build once it completes.
    """

    def run(self):
        start_time = time.perf_counter()
        super().run()
        total_time = time.perf_counter() - start_time
        print(f"Total time for bdist_wheel: {total_time:.2f} seconds")
Phuong Nguyen's avatar
Phuong Nguyen committed
60
61


62
def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library

    Translates ``NVTE_*`` environment variables into CMake ``-D`` flags for the
    ``transformer_engine/common`` CMake project.

    Returns:
        A ``CMakeExtension`` named ``transformer_engine`` whose CMake project is
        ``transformer_engine/common`` next to this file.

    Raises:
        AssertionError: if ``NVTE_UB_WITH_MPI=1`` is set without ``MPI_HOME``, or
            ``NVTE_ENABLE_NVSHMEM=1`` is set without ``NVSHMEM_HOME``.
    """
    is_rocm = rocm_build()

    if is_rocm:
        cmake_flags = []
        # ROCm-only warning suppression; unused-variable suppression is on by default.
        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"))):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_UNUSED_WARNING=ON")
        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "0"))):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=ON")
        # The env var lacks the ``_WARNING`` suffix that the CMake option carries.
        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "0"))):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=ON")
    else:
        cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]

    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
        assert (
            os.getenv("MPI_HOME") is not None
        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

    if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    # Project directory root
    root_path = Path(__file__).resolve().parent

    if is_rocm:
        # NOTE(review): unlike the other NVTE_* switches these only test for
        # presence, so e.g. NVTE_USE_HIPBLASLT=0 still enables hipBLASLt —
        # confirm this is intended.
        if os.getenv("NVTE_USE_HIPBLASLT") is not None:
            cmake_flags.append("-DUSE_HIPBLASLT=ON")
        if os.getenv("NVTE_USE_ROCBLAS") is not None:
            cmake_flags.append("-DUSE_ROCBLAS=ON")

    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / Path("transformer_engine/common"),
        cmake_flags=cmake_flags,
    )


def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for build, runtime, and testing.

    Returns:
        A 3-tuple ``(setup_reqs, install_reqs, test_reqs)``, each passed through
        ``remove_dups``.
    """

    # Common requirements
    # NOTE(review): these CUDA wheels are also requested for ROCm builds —
    # confirm whether they should be skipped when rocm_build() is true.
    setup_reqs: List[str] = [
        "nvidia-cuda-runtime-cu12",
        "nvidia-cublas-cu12",
        "nvidia-cudnn-cu12",
        "nvidia-cuda-cccl-cu12",
        "nvidia-cuda-nvcc-cu12",
        "nvidia-nvtx-cu12",
        "nvidia-cuda-nvrtc-cu12",
    ]
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Requirements that may be installed outside of Python
    if not found_cmake():
        setup_reqs.append("cmake>=3.21")
    if not found_ninja():
        setup_reqs.append("ninja")
    if not found_pybind11():
        setup_reqs.append("pybind11")

    # Framework-specific requirements are skipped for release builds.
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            setup_reqs.append("torch>=2.1")
            install_reqs.append("torch>=2.1")
            # install_reqs.append(
            #     "nvdlfw-inspect @"
            #     " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
            # )
            # Blackwell is not supported as of Triton 3.2.0, need custom internal build
            # install_reqs.append("triton")
            test_reqs.extend(["numpy", "torchvision"])
        if "jax" in frameworks:
            setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
            install_reqs.extend(["jax", "flax>=0.7.1"])
            test_reqs.append("numpy")

    # Return a real tuple, matching the annotated return type.
    return remove_dups(setup_reqs), remove_dups(install_reqs), remove_dups(test_reqs)
Tim Moon's avatar
Tim Moon committed
154

155

156
157
if __name__ == "__main__":
    __version__ = te_version()

    # Resolve relative to this file so the build does not depend on the CWD.
    with open(current_file_path / "README.rst", encoding="utf-8") as f:
        long_description = f.read()

    release_build = bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))

    # Settings for building top level empty package for dependency management.
    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        assert release_build, "NVTE_RELEASE_BUILD env must be set for metapackage build."
        ext_modules = []
        package_data = {}
        include_package_data = False
        setup_requires = []
        # A flat list of requirement strings. The previous `([...],)` form handed
        # setuptools a tuple containing a list instead.
        install_requires = [f"transformer_engine_cu12=={__version__}"]
        extras_require = {
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
        }
    else:
        setup_requires, install_requires, test_requires = setup_requirements()
        ext_modules = [setup_common_extension()]
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True
        extras_require = {"test": test_requires}

        # Framework extensions are only built for non-release builds.
        if not release_build:
            if "pytorch" in frameworks:
                from build_tools.pytorch import setup_pytorch_extension

                ext_modules.append(
                    setup_pytorch_extension(
                        "transformer_engine/pytorch/csrc",
                        current_file_path / "transformer_engine" / "pytorch" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )
            if "jax" in frameworks:
                from build_tools.jax import setup_jax_extension

                ext_modules.append(
                    setup_jax_extension(
                        "transformer_engine/jax/csrc",
                        current_file_path / "transformer_engine" / "jax" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        packages=setuptools.find_packages(
            # NOTE(review): find_packages matches dotted package names, so the
            # slash-separated "transformer_engine/build_tools" pattern appears to
            # match nothing — confirm whether "build_tools" was intended.
            include=[
                "transformer_engine",
                "transformer_engine.*",
                "transformer_engine/build_tools",
            ],
        ),
        extras_require=extras_require,
        description="Transformer acceleration library",
        long_description=long_description,
        long_description_content_type="text/x-rst",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
        python_requires=">=3.8, <3.13",
        classifiers=[
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: 3.12",
        ],
        setup_requires=setup_requires,
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=include_package_data,
        package_data=package_data,
    )