# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Installation script."""

# Example ROCm build commands (install, then build a wheel):
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel

from importlib import metadata
Przemek Tredak's avatar
Przemek Tredak committed
10
import os
Phuong Nguyen's avatar
Phuong Nguyen committed
11
import time
Tim Moon's avatar
Tim Moon committed
12
from pathlib import Path
13
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
14

Tim Moon's avatar
Tim Moon committed
15
import setuptools
Phuong Nguyen's avatar
Phuong Nguyen committed
16
from wheel.bdist_wheel import bdist_wheel
Przemek Tredak's avatar
Przemek Tredak committed
17

18
from build_tools.build_ext import CMakeExtension, get_build_ext
19
from build_tools.te_version import te_version
20
from build_tools.utils import (
yuguo's avatar
yuguo committed
21
    rocm_build,
22
    cuda_archs,
23
    cuda_version,
24
    get_frameworks,
25
    remove_dups,
26
)
Przemek Tredak's avatar
Przemek Tredak committed
27

28
29
# Frameworks selected for this build, and the directory containing setup.py.
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()

# Plain setuptools build_ext is the fallback; a framework-specific
# implementation replaces it below when pytorch or jax is being built.
from setuptools.command.build_ext import build_ext as BuildExtension

# NOTE(review): presumably consumed by build_tools / package code to detect an
# in-tree project build — confirm against readers of NVTE_PROJECT_BUILDING.
os.environ["NVTE_PROJECT_BUILDING"] = "1"

# PyTorch takes precedence over JAX when both are requested.
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    from pybind11.setup_helpers import build_ext as BuildExtension

# Wrap the chosen build_ext so it can also drive the CMake-based extension.
CMakeBuildExtension = get_build_ext(BuildExtension)

# ROCm builds pass no CUDA architecture list to CMake.
archs = None if rocm_build() else cuda_archs()


class TimedBdist(bdist_wheel):
    """bdist_wheel command that reports how long the wheel build took."""

    def run(self):
        t0 = time.perf_counter()
        super().run()
        print(f"Total time for bdist_wheel: {time.perf_counter() - t0:.2f} seconds")


def _env_flag(name: str, default: str = "0") -> bool:
    """Return whether the integer-valued environment variable ``name`` is nonzero.

    ``default`` is used when the variable is unset. As with the inline
    ``bool(int(os.getenv(...)))`` checks used elsewhere in this file, a
    non-integer value raises ``ValueError``.
    """
    return bool(int(os.getenv(name, default)))


def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library

    Translates ``NVTE_*`` environment variables into CMake ``-D`` flags and
    returns a ``CMakeExtension`` for ``transformer_engine/common``.

    Raises:
        AssertionError: if ``NVTE_UB_WITH_MPI=1`` without ``MPI_HOME``, or
            ``NVTE_ENABLE_NVSHMEM=1`` without ``NVSHMEM_HOME``.
        ValueError: if a boolean ``NVTE_*`` switch is not an integer string.
    """
    is_rocm = rocm_build()

    if is_rocm:
        # ROCm: no CUDA architecture list; warning-suppression switches default to on.
        cmake_flags = []
        if _env_flag("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_UNUSED_WARNING=ON")
        if _env_flag("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "1"):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=ON")
        # NOTE: this env var lacks the "_WARNING" suffix that the CMake option has.
        if _env_flag("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "1"):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=ON")
    else:
        cmake_flags = [f"-DCMAKE_CUDA_ARCHITECTURES={archs}"]

    if _env_flag("NVTE_UB_WITH_MPI"):
        assert (
            os.getenv("MPI_HOME") is not None
        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

    if _env_flag("NVTE_ENABLE_NVSHMEM"):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

    if _env_flag("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH"):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    if _env_flag("NVTE_WITH_CUBLASMP"):
        cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
        cuda_major = cuda_version()[0]
        # An explicit *_HOME override wins; otherwise locate the installed
        # nvidia-* distribution for this CUDA major version.
        cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution(
            f"nvidia-cublasmp-cu{cuda_major}"
        ).locate_file(f"nvidia/cublasmp/cu{cuda_major}")
        cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
        nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
            f"nvidia-nvshmem-cu{cuda_major}"
        ).locate_file("nvidia/nvshmem")
        cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
        print("CMAKE_FLAGS:", cmake_flags[-2:])

    # Add custom CMake arguments from environment variable
    # (split on whitespace only; shell-style quoting is not interpreted).
    nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
    if nvte_cmake_extra_args:
        cmake_flags.extend(nvte_cmake_extra_args.split())

    # Project directory root
    root_path = Path(__file__).resolve().parent

    if is_rocm:
        # Parse like every other NVTE_* switch: "0" disables, nonzero enables.
        # A bare `is not None` check would treat NVTE_USE_ROCBLAS=0 as "on".
        if _env_flag("NVTE_USE_HIPBLASLT"):
            cmake_flags.append("-DUSE_HIPBLASLT=ON")
        if _env_flag("NVTE_USE_ROCBLAS"):
            cmake_flags.append("-DUSE_ROCBLAS=ON")

    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / Path("transformer_engine/common"),
        cmake_flags=cmake_flags,
    )


def setup_requirements() -> Tuple[List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for runtime and testing.

    Returns:
        A ``(install_requires, test_requires)`` tuple, each passed through
        ``remove_dups``. Framework-specific requirements from
        ``build_tools.pytorch`` / ``build_tools.jax`` are added only when
        ``NVTE_RELEASE_BUILD`` is unset or ``0``.
    """

    # Common requirements
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Framework-specific requirements
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            from build_tools.pytorch import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())
        if "jax" in frameworks:
            from build_tools.jax import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())

    # Return an actual 2-tuple to match the annotation; callers unpack it.
    return remove_dups(install_reqs), remove_dups(test_reqs)


if __name__ == "__main__":
    __version__ = te_version()

    # Resolve README next to setup.py rather than relative to the current directory.
    with open(current_file_path / "README.rst", encoding="utf-8") as f:
        long_description = f.read()

    # Settings for building top level empty package for dependency management.
    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        assert bool(
            int(os.getenv("NVTE_RELEASE_BUILD", "0"))
        ), "NVTE_RELEASE_BUILD env must be set for metapackage build."
        ext_modules = []
        package_data = {}
        include_package_data = False
        # A flat list of requirement strings. The previous `([...],)` had a
        # stray trailing comma that made this a tuple containing a list.
        install_requires = [f"transformer_engine_cu12=={__version__}"]
        extras_require = {
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
        }
    else:
        install_requires, test_requires = setup_requirements()
        ext_modules = [setup_common_extension()]
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True
        extras_require = {"test": test_requires}

        # Framework extensions are built here only for non-release builds.
        if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
            if "pytorch" in frameworks:
                from build_tools.pytorch import setup_pytorch_extension

                ext_modules.append(
                    setup_pytorch_extension(
                        "transformer_engine/pytorch/csrc",
                        current_file_path / "transformer_engine" / "pytorch" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )
            if "jax" in frameworks:
                from build_tools.jax import setup_jax_extension

                ext_modules.append(
                    setup_jax_extension(
                        "transformer_engine/jax/csrc",
                        current_file_path / "transformer_engine" / "jax" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        packages=setuptools.find_packages(
            # NOTE(review): find_packages patterns match dotted package names,
            # so the slash-style entry below likely never matches — confirm.
            include=[
                "transformer_engine",
                "transformer_engine.*",
                "transformer_engine/build_tools",
            ],
        ),
        extras_require=extras_require,
        description="Transformer acceleration library",
        long_description=long_description,
        long_description_content_type="text/x-rst",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
        python_requires=">=3.8",
        classifiers=["Programming Language :: Python :: 3"],
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=include_package_data,
        package_data=package_data,
    )