setup.py 8.66 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
"""Installation script."""
wenjh's avatar
wenjh committed
6
7
8
9
# Build and install command:
# NVTE_BUILD_SUPPRESS_UNUSED_WARNING=1 NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=1 NVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=1 NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=0 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py install -v
# Build a wheel (.whl) package:
# NVTE_BUILD_SUPPRESS_UNUSED_WARNING=1 NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=1 NVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=1 NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=0 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
10

11
from importlib import metadata
Przemek Tredak's avatar
Przemek Tredak committed
12
import os
Phuong Nguyen's avatar
Phuong Nguyen committed
13
import time
Tim Moon's avatar
Tim Moon committed
14
from pathlib import Path
15
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
16

Tim Moon's avatar
Tim Moon committed
17
import setuptools
Phuong Nguyen's avatar
Phuong Nguyen committed
18
from wheel.bdist_wheel import bdist_wheel
Przemek Tredak's avatar
Przemek Tredak committed
19

20
from build_tools.build_ext import CMakeExtension, get_build_ext
21
from build_tools.te_version import te_version
22
from build_tools.utils import (
yuguo's avatar
yuguo committed
23
    rocm_build,
24
    cuda_archs,
25
    cuda_version,
26
    get_frameworks,
27
    remove_dups,
28
    min_python_version_str,
29
)
Przemek Tredak's avatar
Przemek Tredak committed
30

31
32
# Frameworks whose bindings will be built, and the directory holding this script.
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()

# Plain setuptools build_ext; overridden below when PyTorch or JAX is built.
from setuptools.command.build_ext import build_ext as BuildExtension

# Exported into the process environment (inherited by subprocesses).
# NOTE(review): presumably lets project code detect it is running inside the
# build — confirm against readers of NVTE_PROJECT_BUILDING.
os.environ.update({"NVTE_PROJECT_BUILDING": "1"})

# Pick the framework-specific build_ext; PyTorch wins when both are requested.
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    from pybind11.setup_helpers import build_ext as BuildExtension

# build_ext command class registered with setuptools.setup() (cmdclass), built
# from whichever BuildExtension was selected above.
CMakeBuildExtension = get_build_ext(BuildExtension)

# CUDA architecture list for the common library; unused (None) on ROCm builds.
archs = None if rocm_build() else cuda_archs()
50

Tim Moon's avatar
Tim Moon committed
51

Phuong Nguyen's avatar
Phuong Nguyen committed
52
53
54
55
56
57
58
class TimedBdist(bdist_wheel):
    """``bdist_wheel`` command that reports how long the wheel build took."""

    def run(self):
        """Run the stock ``bdist_wheel`` command, then print its duration.

        The duration is measured with ``time.perf_counter``. Nothing is
        printed if the underlying command raises.
        """
        clock = time.perf_counter
        began = clock()
        super().run()
        print(f"Total time for bdist_wheel: {clock() - began:.2f} seconds")
Phuong Nguyen's avatar
Phuong Nguyen committed
60
61


62
def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library

    Collects CMake ``-D`` options for ``transformer_engine/common`` from
    environment variables and returns a ``CMakeExtension`` named
    ``transformer_engine`` rooted at that directory.

    Environment variables read:
        NVTE_BUILD_SUPPRESS_UNUSED_WARNING,
        NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING,
        NVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING:
            ROCm builds only; integer flags, default ``1``.
            ``NVTE_BUILD_SUPPRESS_SIGN_COMPARE`` is accepted as a legacy
            fallback for the last one.
        NVTE_USE_HIPBLASLT, NVTE_USE_ROCBLAS:
            ROCm builds only; the option is enabled whenever the variable is
            set, whatever its value.
        NVTE_UB_WITH_MPI: integer flag, default ``0``; requires MPI_HOME.
        NVTE_ENABLE_NVSHMEM: integer flag, default ``0``; requires NVSHMEM_HOME.
        NVTE_BUILD_ACTIVATION_WITH_FAST_MATH: integer flag, default ``0``.
        NVTE_WITH_CUBLASMP:
            integer flag, default ``0``. The cuBLASMp and NVSHMEM locations
            come from CUBLASMP_HOME / NVSHMEM_HOME or, when unset, from the
            installed ``nvidia-cublasmp-cu<N>`` / ``nvidia-nvshmem-cu<N>``
            distributions (``<N>`` = first element of ``cuda_version()``).
        NVTE_CMAKE_EXTRA_ARGS: extra CMake arguments, split on whitespace.

    Raises:
        AssertionError: (when assertions are enabled) if MPI_HOME or
            NVSHMEM_HOME is missing for a feature that requires it.
        ValueError: if an integer flag holds a non-integer string.
    """
    if rocm_build():
        # ROCm builds pass no CUDA architectures; they may instead suppress
        # some compiler warnings (all suppressed unless explicitly set to 0).
        cmake_flags = []
        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"))):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_UNUSED_WARNING=ON")
        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "1"))):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=ON")
        # Use the same "..._WARNING" spelling as the sibling flags, the CMake
        # option below and the documented build command; otherwise setting the
        # documented variable has no effect. The shorter legacy name is still
        # honoured as a fallback so existing build scripts keep working.
        suppress_sign_compare = os.getenv(
            "NVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING",
            os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "1"),
        )
        if bool(int(suppress_sign_compare)):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=ON")
    else:
        cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]

    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
        assert (
            os.getenv("MPI_HOME") is not None
        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

    if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))):
        cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
        # cuda_version() is only consulted when the *_HOME override is unset
        # (short-circuit `or`), so an explicit path avoids that lookup.
        cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution(
            f"nvidia-cublasmp-cu{cuda_version()[0]}"
        ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
        cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
        nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
            f"nvidia-nvshmem-cu{cuda_version()[0]}"
        ).locate_file("nvidia/nvshmem")
        cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
        print("CMAKE_FLAGS:", cmake_flags[-2:])

    # Add custom CMake arguments from environment variable
    nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
    if nvte_cmake_extra_args:
        cmake_flags.extend(nvte_cmake_extra_args.split())

    # Project directory root
    root_path = Path(__file__).resolve().parent

    if rocm_build():
        # Presence-based switches: any value (even "0") enables the option.
        if os.getenv("NVTE_USE_HIPBLASLT") is not None:
            cmake_flags.append("-DUSE_HIPBLASLT=ON")
        if os.getenv("NVTE_USE_ROCBLAS") is not None:
            cmake_flags.append("-DUSE_ROCBLAS=ON")

    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / Path("transformer_engine/common"),
        cmake_flags=cmake_flags,
    )


122
def setup_requirements() -> Tuple[List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for runtime and testing.

    Returns:
        ``(install_requires, test_requires)``: the runtime and test
        requirement lists, each passed through ``remove_dups``. PyTorch- and
        JAX-specific requirements from ``build_tools`` are appended only for
        frameworks listed in the module-level ``frameworks`` and only when
        ``NVTE_RELEASE_BUILD`` is unset or ``0``.

    Raises:
        ValueError: if ``NVTE_RELEASE_BUILD`` is not an integer string.
    """

    # Common requirements
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Framework-specific requirements
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            from build_tools.pytorch import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())
        if "jax" in frameworks:
            from build_tools.jax import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())

    # Return a tuple so the value matches the annotated return type; callers
    # unpack the pair, which works identically for a tuple.
    return remove_dups(install_reqs), remove_dups(test_reqs)
Tim Moon's avatar
Tim Moon committed
150

151

152
153
if __name__ == "__main__":
    __version__ = te_version()

    # Package-index long description, read from the current working directory.
    long_description = Path("README.rst").read_text(encoding="utf-8")

    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        # Empty top-level package used only for dependency management: no
        # extensions or data, just extras pinning the versioned components.
        assert bool(
            int(os.getenv("NVTE_RELEASE_BUILD", "0"))
        ), "NVTE_RELEASE_BUILD env must be set for metapackage build."
        extra_to_dist = {
            "core": "transformer_engine_cu12",
            "core_cu12": "transformer_engine_cu12",
            "core_cu13": "transformer_engine_cu13",
            "pytorch": "transformer_engine_torch",
            "jax": "transformer_engine_jax",
        }
        extras_require = {
            extra: [f"{dist}=={__version__}"] for extra, dist in extra_to_dist.items()
        }
        install_requires = []
        include_package_data = False
        package_data = {}
        ext_modules = []
    else:
        # Regular build: the common CMake library plus, for non-release
        # builds, the bindings of each requested framework.
        install_requires, test_requires = setup_requirements()
        extras_require = {"test": test_requires}
        ext_modules = [setup_common_extension()]
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True

        if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
            te_src = current_file_path / "transformer_engine"
            if "pytorch" in frameworks:
                from build_tools.pytorch import setup_pytorch_extension

                ext_modules.append(
                    setup_pytorch_extension(
                        "transformer_engine/pytorch/csrc",
                        te_src / "pytorch" / "csrc",
                        te_src,
                    )
                )
            if "jax" in frameworks:
                from build_tools.jax import setup_jax_extension

                ext_modules.append(
                    setup_jax_extension(
                        "transformer_engine/jax/csrc",
                        te_src / "jax" / "csrc",
                        te_src,
                    )
                )

    # NOTE(review): find_packages matches dotted package names, so the
    # slash-separated "transformer_engine/build_tools" pattern likely matches
    # nothing — confirm whether build_tools was meant to be shipped.
    te_packages = setuptools.find_packages(
        include=[
            "transformer_engine",
            "transformer_engine.*",
            "transformer_engine/build_tools",
        ],
    )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        packages=te_packages,
        extras_require=extras_require,
        description="Transformer acceleration library",
        long_description=long_description,
        long_description_content_type="text/x-rst",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
        python_requires=f">={min_python_version_str()}",
        classifiers=["Programming Language :: Python :: 3"],
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=include_package_data,
        package_data=package_data,
    )