setup.py 8.3 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
"""Installation script."""
yuguo's avatar
yuguo committed
6
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
yuguo's avatar
yuguo committed
7
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
8

9
from importlib import metadata
Przemek Tredak's avatar
Przemek Tredak committed
10
import os
Phuong Nguyen's avatar
Phuong Nguyen committed
11
import time
Tim Moon's avatar
Tim Moon committed
12
from pathlib import Path
13
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
14

Tim Moon's avatar
Tim Moon committed
15
import setuptools
Phuong Nguyen's avatar
Phuong Nguyen committed
16
from wheel.bdist_wheel import bdist_wheel
Przemek Tredak's avatar
Przemek Tredak committed
17

18
from build_tools.build_ext import CMakeExtension, get_build_ext
19
from build_tools.te_version import te_version
20
from build_tools.utils import (
yuguo's avatar
yuguo committed
21
    rocm_build,
22
    cuda_archs,
23
    cuda_version,
24
    get_frameworks,
25
    remove_dups,
26
    min_python_version_str,
27
)
Przemek Tredak's avatar
Przemek Tredak committed
28

29
30
# Frameworks selected for this build, as reported by build_tools.utils.
frameworks = get_frameworks()
# Absolute directory that contains this setup script.
current_file_path = Path(__file__).parent.resolve()


# Default build_ext implementation; replaced below when a framework that
# ships its own extension builder is being built.
from setuptools.command.build_ext import build_ext as BuildExtension


# Exported before any framework import so downstream code can detect that the
# project itself is being built.
os.environ["NVTE_PROJECT_BUILDING"] = "1"


if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    from pybind11.setup_helpers import build_ext as BuildExtension


# build_ext command class that also drives the CMake build.
CMakeBuildExtension = get_build_ext(BuildExtension)

# CUDA architecture list; not used for ROCm builds.
archs = None if rocm_build() else cuda_archs()
48

Tim Moon's avatar
Tim Moon committed
49

Phuong Nguyen's avatar
Phuong Nguyen committed
50
51
52
53
54
55
56
class TimedBdist(bdist_wheel):
    """``bdist_wheel`` command that reports the wall-clock time of the build."""

    def run(self):
        # Time only the parent command's run; the report is printed once it
        # returns normally.
        began_at = time.perf_counter()
        super().run()
        elapsed_seconds = time.perf_counter() - began_at
        print(f"Total time for bdist_wheel: {elapsed_seconds:.2f} seconds")
Phuong Nguyen's avatar
Phuong Nguyen committed
58
59


60
def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library

    Translates environment variables into CMake ``-D`` flags and wraps them in
    a ``CMakeExtension`` for the ``transformer_engine/common`` CMake project.

    Returns:
        A ``CMakeExtension`` named ``transformer_engine`` whose CMake source
        directory is ``transformer_engine/common`` next to this script.

    Raises:
        AssertionError: if ``NVTE_UB_WITH_MPI=1`` without ``MPI_HOME``, or
            ``NVTE_ENABLE_NVSHMEM=1`` without ``NVSHMEM_HOME`` (these are
            ``assert`` checks, so they are skipped under ``python -O``).
        ValueError: if a numeric switch such as ``NVTE_UB_WITH_MPI`` holds a
            value that ``int()`` cannot parse.
        importlib.metadata.PackageNotFoundError: if ``NVTE_WITH_CUBLASMP=1``,
            ``CUBLASMP_HOME``/``NVSHMEM_HOME`` are unset, and the matching
            ``nvidia-*`` distribution is not installed.
    """

    def _switch_is_on(var_name: str) -> bool:
        # Presence-style switch: unset means off; any set value means on,
        # except an explicit off value ("0", "false", "off", "no").
        value = os.getenv(var_name)
        if value is None:
            return False
        return value.strip().lower() not in ("0", "false", "off", "no")

    if rocm_build():
        # ROCm builds pass no CUDA architecture flag. The warning-suppression
        # options below default to ON; set the variable to "0" to disable one.
        cmake_flags = []
        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"))):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_UNUSED_WARNING=ON")
        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "1"))):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=ON")
        # NOTE: this env var lacks the "_WARNING" suffix of the CMake option.
        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "1"))):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=ON")
    else:
        # `archs` is the module-level architecture list from cuda_archs().
        cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]

    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
        assert (
            os.getenv("MPI_HOME") is not None
        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

    if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))):
        cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
        # Explicit *_HOME paths win. Otherwise fall back to the installed
        # nvidia-* wheel for the detected CUDA major version. cuda_version()
        # is only evaluated when that fallback is actually needed.
        cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution(
            f"nvidia-cublasmp-cu{cuda_version()[0]}"
        ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
        cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
        nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
            f"nvidia-nvshmem-cu{cuda_version()[0]}"
        ).locate_file("nvidia/nvshmem")
        cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
        print("CMAKE_FLAGS:", cmake_flags[-2:])

    # Add custom CMake arguments from environment variable (whitespace-split)
    nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
    if nvte_cmake_extra_args:
        cmake_flags.extend(nvte_cmake_extra_args.split())

    # Project directory root
    root_path = Path(__file__).resolve().parent

    if rocm_build():
        # Presence-style switches: setting them to an explicit off value such
        # as "0" must not enable the backend.
        if _switch_is_on("NVTE_USE_HIPBLASLT"):
            cmake_flags.append("-DUSE_HIPBLASLT=ON")
        if _switch_is_on("NVTE_USE_ROCBLAS"):
            cmake_flags.append("-DUSE_ROCBLAS=ON")

    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / Path("transformer_engine/common"),
        cmake_flags=cmake_flags,
    )


120
def setup_requirements() -> Tuple[List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for runtime and testing.

    Framework-specific requirements (from ``build_tools.pytorch`` /
    ``build_tools.jax``) are added for each framework in the module-level
    ``frameworks``, unless ``NVTE_RELEASE_BUILD`` is set to a non-zero integer.

    Returns:
        ``(install_requires, test_requires)``, each passed through
        ``remove_dups``.
    """

    # Common requirements
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Framework-specific requirements are skipped for release builds.
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            from build_tools.pytorch import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())
        if "jax" in frameworks:
            from build_tools.jax import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())

    # Return a real tuple to match the annotation; callers unpack two values.
    return remove_dups(install_reqs), remove_dups(test_reqs)
Tim Moon's avatar
Tim Moon committed
148

149

150
151
if __name__ == "__main__":
    __version__ = te_version()

    # Resolve README next to this script so the build does not depend on the
    # current working directory.
    with open(current_file_path / "README.rst", encoding="utf-8") as f:
        long_description = f.read()

    release_build = bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))

    # Settings for building top level empty package for dependency management.
    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        assert release_build, "NVTE_RELEASE_BUILD env must be set for metapackage build."
        ext_modules = []
        package_data = {}
        include_package_data = False
        install_requires = []
        # The metapackage only pins the separately built distributions.
        extras_require = {
            "core": [f"transformer_engine_cu12=={__version__}"],
            "core_cu12": [f"transformer_engine_cu12=={__version__}"],
            "core_cu13": [f"transformer_engine_cu13=={__version__}"],
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
        }
    else:
        install_requires, test_requires = setup_requirements()
        ext_modules = [setup_common_extension()]
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True
        extras_require = {"test": test_requires}

        # Release builds skip the framework extensions here (presumably they
        # ship via the transformer_engine_torch/_jax packages — confirm).
        if not release_build:
            if "pytorch" in frameworks:
                from build_tools.pytorch import setup_pytorch_extension

                ext_modules.append(
                    setup_pytorch_extension(
                        "transformer_engine/pytorch/csrc",
                        current_file_path / "transformer_engine" / "pytorch" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )
            if "jax" in frameworks:
                from build_tools.jax import setup_jax_extension

                ext_modules.append(
                    setup_jax_extension(
                        "transformer_engine/jax/csrc",
                        current_file_path / "transformer_engine" / "jax" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        packages=setuptools.find_packages(
            # NOTE(review): find_packages matches dotted package names, so the
            # slash-separated "transformer_engine/build_tools" pattern likely
            # never matches anything — confirm intent before changing it.
            include=[
                "transformer_engine",
                "transformer_engine.*",
                "transformer_engine/build_tools",
            ],
        ),
        extras_require=extras_require,
        description="Transformer acceleration library",
        long_description=long_description,
        long_description_content_type="text/x-rst",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
        python_requires=f">={min_python_version_str()}",
        classifiers=["Programming Language :: Python :: 3"],
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=include_package_data,
        package_data=package_data,
    )