setup.py 13 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
"""Installation script."""
wenjh's avatar
wenjh committed
6
# Build-and-install command:
wenjh's avatar
wenjh committed
7
# NVTE_BUILD_SUPPRESS_UNUSED_WARNING=1 NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=1 NVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=1 NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=0 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH pip install --no-build-isolation . -v
wenjh's avatar
wenjh committed
8
9
# Command for building the wheel (WHL) package:
# NVTE_BUILD_SUPPRESS_UNUSED_WARNING=1 NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=1 NVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=1 NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=0 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
10

11
from importlib import metadata
Przemek Tredak's avatar
Przemek Tredak committed
12
import os
13
14
import shutil
import subprocess
Phuong Nguyen's avatar
Phuong Nguyen committed
15
import time
Tim Moon's avatar
Tim Moon committed
16
from pathlib import Path
17
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
18

Tim Moon's avatar
Tim Moon committed
19
import setuptools
Phuong Nguyen's avatar
Phuong Nguyen committed
20
from wheel.bdist_wheel import bdist_wheel
Przemek Tredak's avatar
Przemek Tredak committed
21

22
from build_tools.build_ext import CMakeExtension, get_build_ext
23
from build_tools.te_version import te_version
24
from build_tools.utils import (
yuguo's avatar
yuguo committed
25
    rocm_build,
26
    cuda_archs,
27
    cuda_version,
28
    get_frameworks,
29
    remove_dups,
30
    min_python_version_str,
31
)
Przemek Tredak's avatar
Przemek Tredak committed
32

33
34
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()
Przemek Tredak's avatar
Przemek Tredak committed
35
36


37
from setuptools.command.build_ext import build_ext as BuildExtension
38

39
40
os.environ["NVTE_PROJECT_BUILDING"] = "1"

41
42
43
44
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    from pybind11.setup_helpers import build_ext as BuildExtension
Tim Moon's avatar
Tim Moon committed
45
46


47
CMakeBuildExtension = get_build_ext(BuildExtension)
yuguo's avatar
yuguo committed
48
49
50
51
# CUDA architectures are only queried for non-ROCm builds.
archs = None if rocm_build() else cuda_archs()
52

wenjh's avatar
wenjh committed
53
if bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))):
    # Hygon backend build: mirror the upstream sources into the
    # `transformer_engine_hygon` tree, replacing any previous copy.
    _upstream_pkg = current_file_path / "transformer_engine"
    _hygon_pkg = current_file_path / "transformer_engine_hygon"

    common_dir, common_copy = _upstream_pkg / "common", _hygon_pkg / "common"
    csrc_dir = _upstream_pkg / "pytorch" / "csrc"
    csrc_copy = _hygon_pkg / "pytorch" / "csrc"

    for _src, _dst in ((common_dir, common_copy), (csrc_dir, csrc_copy)):
        # copytree() requires a non-existent destination, so drop stale copies first.
        if _dst.exists():
            shutil.rmtree(_dst)
        shutil.copytree(_src, _dst)
Tim Moon's avatar
Tim Moon committed
64

Phuong Nguyen's avatar
Phuong Nguyen committed
65
66
67
68
69
70
71
class TimedBdist(bdist_wheel):
    """bdist_wheel command that reports how long the wheel build took."""

    def run(self):
        """Run the parent bdist_wheel step and print the elapsed time."""
        began = time.perf_counter()
        super().run()
        elapsed = time.perf_counter() - began
        print(f"Total time for bdist_wheel: {elapsed:.2f} seconds")
Phuong Nguyen's avatar
Phuong Nguyen committed
73
74


75
def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library

    Builds the list of CMake flags from the build environment and returns the
    `CMakeExtension` for the common library.

    Environment variables consulted:
      * ROCm builds only: `NVTE_BUILD_SUPPRESS_UNUSED_WARNING`,
        `NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING`,
        `NVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING` (all default to "1"),
        `NVTE_USE_HIPBLASLT`, `NVTE_USE_ROCBLAS`.
      * `NVTE_UB_WITH_MPI` (requires `MPI_HOME`), `NVTE_ENABLE_NVSHMEM`
        (requires `NVSHMEM_HOME`), `NVTE_BUILD_ACTIVATION_WITH_FAST_MATH`,
        `NVTE_WITH_CUBLASMP` (optionally `CUBLASMP_HOME` / `NVSHMEM_HOME`).
      * `NVTE_CMAKE_EXTRA_ARGS`: whitespace-separated extra CMake arguments.
      * `TEFL_HYGON_BACKEND`: selects the `transformer_engine_hygon` tree.

    Raises:
        AssertionError: if MPI or NVSHMEM support is requested without the
            corresponding `*_HOME` variable.
        ValueError: if an integer-valued flag variable is not a valid integer.
    """

    def env_flag(name: str, default: str = "0") -> bool:
        """Interpret an integer-valued environment variable as a boolean."""
        return bool(int(os.getenv(name, default)))

    if rocm_build():
        cmake_flags = []
        if env_flag("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_UNUSED_WARNING=ON")
        if env_flag("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "1"):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=ON")
        # The documented variable (and the CMake option) carries the `_WARNING`
        # suffix like its siblings. The suffix-less spelling was the only name
        # read previously, so it is still honored as a fallback.
        legacy_sign_compare = os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "1")
        if env_flag("NVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING", legacy_sign_compare):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=ON")
    else:
        cmake_flags = [f"-DCMAKE_CUDA_ARCHITECTURES={archs}"]

    if env_flag("NVTE_UB_WITH_MPI"):
        assert (
            os.getenv("MPI_HOME") is not None
        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

    if env_flag("NVTE_ENABLE_NVSHMEM"):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

    if env_flag("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH"):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    if env_flag("NVTE_WITH_CUBLASMP"):
        cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
        # Explicit *_HOME variables win; otherwise locate the directories inside
        # the installed nvidia-* distributions. The metadata lookup (and the
        # cuda_version() call) only happens when the variable is unset or empty.
        cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution(
            f"nvidia-cublasmp-cu{cuda_version()[0]}"
        ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
        cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
        nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
            f"nvidia-nvshmem-cu{cuda_version()[0]}"
        ).locate_file("nvidia/nvshmem")
        cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
        print("CMAKE_FLAGS:", cmake_flags[-2:])

    # Add custom CMake arguments from environment variable
    nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
    if nvte_cmake_extra_args:
        cmake_flags.extend(nvte_cmake_extra_args.split())

    # Project directory root
    root_path = Path(__file__).resolve().parent

    if rocm_build():
        # NOTE: these two are presence checks, not integer flags: any value,
        # including "0", enables the corresponding BLAS backend.
        if os.getenv("NVTE_USE_HIPBLASLT") is not None:
            cmake_flags.append("-DUSE_HIPBLASLT=ON")
        if os.getenv("NVTE_USE_ROCBLAS") is not None:
            cmake_flags.append("-DUSE_ROCBLAS=ON")

    # The Hygon backend builds from its own source tree under its own name.
    if env_flag("TEFL_HYGON_BACKEND"):
        package_name = "transformer_engine_hygon"
    else:
        package_name = "transformer_engine"

    return CMakeExtension(
        name=package_name,
        cmake_path=root_path / package_name / "common",
        cmake_flags=cmake_flags,
    )


139
def setup_requirements() -> Tuple[List[str], List[str]]:
    """Collect the Python dependencies of the package.

    Returns the runtime requirements followed by the test requirements, each
    passed through `remove_dups`. Framework-specific requirements are only
    added when `NVTE_RELEASE_BUILD` is not enabled.
    """

    # Requirements shared by every build
    runtime_deps: List[str] = ["pydantic", "importlib-metadata>=1.0", "packaging"]
    testing_deps: List[str] = ["pytest>=8.2.1"]

    # Requirements contributed by the selected frameworks
    release_build = bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))
    if not release_build:
        if "pytorch" in frameworks:
            from build_tools.pytorch import install_requirements as torch_install_reqs
            from build_tools.pytorch import test_requirements as torch_test_reqs

            runtime_deps += torch_install_reqs()
            testing_deps += torch_test_reqs()
        if "jax" in frameworks:
            from build_tools.jax import install_requirements as jax_install_reqs
            from build_tools.jax import test_requirements as jax_test_reqs

            runtime_deps += jax_install_reqs()
            testing_deps += jax_test_reqs()

    return [remove_dups(runtime_deps), remove_dups(testing_deps)]
Tim Moon's avatar
Tim Moon committed
167

168

169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
def git_check_submodules() -> None:
    """
    Try to check out the git submodules automatically as part of the build.

    The checkout only goes ahead when every submodule is either uninitialized
    or already on the expected commit; any other state aborts the build with an
    assertion error.

    Note to devs: as a consequence, changes to a submodule itself (e.g. moving it
    to a newer commit) must be committed before building. This also prevents
    stale submodules from being used silently.
    """

    # Development escape hatch: skip all checks on request.
    skip_checks = bool(int(os.getenv("NVTE_SKIP_SUBMODULE_CHECKS_DURING_BUILD", "0")))
    if skip_checks:
        return

    # Nothing to do without a git executable or without a .gitmodules file.
    git_missing = shutil.which("git") is None
    if git_missing or not (current_file_path / ".gitmodules").exists():
        return

    repo_root = str(current_file_path)
    bad_state_message = (
        "Submodules are initialized incorrectly. If this is intended, set the "
        "environment variable `NVTE_SKIP_SUBMODULE_CHECKS_DURING_BUILD` to a "
        "non-zero value to skip these checks during development. Otherwise, "
        "run `git submodule update --init --recursive` to checkout the correct"
        " submodule commits."
    )

    try:
        status_output = subprocess.check_output(
            ["git", "submodule", "status", "--recursive"],
            cwd=repo_root,
            text=True,
        )

        # Status lines start with '-' for an uninitialized submodule and with
        # ' ' for a submodule that sits on the correct commit.
        for status_line in status_output.splitlines():
            assert status_line[0] in (" ", "-"), bad_state_message

        subprocess.check_call(
            ["git", "submodule", "update", "--init", "--recursive"],
            cwd=repo_root,
        )
    except subprocess.CalledProcessError:
        # A failing git command is tolerated; the build simply proceeds.
        return


222
223
if __name__ == "__main__":
    __version__ = te_version()

    git_check_submodules()

    with open("README.rst", encoding="utf-8") as readme:
        long_description = readme.read()

    # The Hygon backend for TransformerEngine-FL is packaged under its own name
    # and built from its own source tree.
    hygon_backend = bool(int(os.getenv("TEFL_HYGON_BACKEND", "0")))
    package_name = "transformer_engine_hygon" if hygon_backend else "transformer_engine"

    # Settings for building top level empty package for dependency management.
    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        assert bool(
            int(os.getenv("NVTE_RELEASE_BUILD", "0"))
        ), "NVTE_RELEASE_BUILD env must be set for metapackage build."
        ext_modules = []
        package_data = {}
        include_package_data = False
        install_requires = []
        extras_require = {
            "core": [f"transformer_engine_cu12=={__version__}"],
            "core_cu12": [f"transformer_engine_cu12=={__version__}"],
            "core_cu13": [f"transformer_engine_cu13=={__version__}"],
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
        }
    else:
        install_requires, test_requires = setup_requirements()
        ext_modules = [setup_common_extension()]
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True
        extras_require = {"test": test_requires}

        # Framework extensions are only built outside of release builds.
        if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
            if "pytorch" in frameworks:
                from build_tools.pytorch import setup_pytorch_extension

                ext_modules.append(
                    setup_pytorch_extension(
                        f"{package_name}/pytorch/csrc",
                        current_file_path / package_name / "pytorch" / "csrc",
                        current_file_path / package_name,
                    )
                )
            if "jax" in frameworks:
                from build_tools.jax import setup_jax_extension

                # The JAX extension is always built from the upstream tree.
                ext_modules.append(
                    setup_jax_extension(
                        "transformer_engine/jax/csrc",
                        current_file_path / "transformer_engine" / "jax" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )

    # Only the package discovery patterns and the description differ between
    # the regular package and the Hygon backend package.
    if hygon_backend:
        package_patterns = [
            "transformer_engine_hygon",
            "transformer_engine_hygon.*",
        ]
        description = "Transformer acceleration library for TransformerEngine-FL"
    else:
        package_patterns = [
            "transformer_engine",
            "transformer_engine.*",
            "transformer_engine/build_tools",
        ]
        description = "Transformer acceleration library"

    # Configure package
    setuptools.setup(
        name=package_name,
        version=__version__,
        packages=setuptools.find_packages(include=package_patterns),
        extras_require=extras_require,
        description=description,
        long_description=long_description,
        long_description_content_type="text/x-rst",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
        python_requires=f">={min_python_version_str()}",
        classifiers=["Programming Language :: Python :: 3"],
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=include_package_data,
        package_data=package_data,
    )