# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Installation script."""

import os
import subprocess
import sys
import sysconfig
import copy
Phuong Nguyen's avatar
Phuong Nguyen committed
12
import time
13
14
15
16
17
18
19
20

from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Type

import setuptools

from .utils import (
yuguo's avatar
yuguo committed
21
22
    rocm_build,
    rocm_path,
23
24
25
26
    cmake_bin,
    debug_build_enabled,
    found_ninja,
    get_frameworks,
27
    nvcc_path,
28
    get_max_jobs_for_parallel_build,
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
)


class CMakeExtension(setuptools.Extension):
    """CMake extension module

    The extension is registered with an empty source list, so the base class
    has nothing to compile; the build is done by running CMake on the project
    located at ``cmake_path`` (see ``_build_cmake``).
    """

    def __init__(
        self,
        name: str,
        cmake_path: Path,
        cmake_flags: Optional[List[str]] = None,
    ) -> None:
        """Store the CMake project location and extra configure flags.

        Args:
            name: Extension name, forwarded to ``setuptools.Extension``.
            cmake_path: Source directory passed to CMake via ``-S``.
            cmake_flags: Extra arguments appended to the CMake configure
                command. ``None`` means no extra arguments.
        """
        super().__init__(name, sources=[])  # No work for base class
        self.cmake_path: Path = cmake_path
        self.cmake_flags: List[str] = [] if cmake_flags is None else cmake_flags

    def _build_cmake(self, build_dir: Path, install_dir: Path) -> None:
        """Configure, build and install the CMake project.

        Args:
            build_dir: CMake binary directory (``-B``); also used as the
                working directory for every CMake invocation.
            install_dir: Value passed as ``CMAKE_INSTALL_PREFIX``.

        Raises:
            RuntimeError: If any CMake command exits with a non-zero status or
                cannot be launched. The original exception is attached as the
                cause.
        """
        # Make sure paths are str
        _cmake_bin = str(cmake_bin())
        cmake_path = str(self.cmake_path)
        build_dir = str(build_dir)
        install_dir = str(install_dir)

        # CMake configure command
        build_type = "Debug" if debug_build_enabled() else "Release"
        configure_command = [
            _cmake_bin,
            "-S",
            cmake_path,
            "-B",
            build_dir,
            f"-DPython_EXECUTABLE={sys.executable}",
            f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}",
            f"-DPython_SITEARCH={sysconfig.get_path('platlib')}",
            f"-DCMAKE_BUILD_TYPE={build_type}",
            f"-DCMAKE_INSTALL_PREFIX={install_dir}",
        ]

        # Optionally route C++ and CUDA compilation through a compiler cache.
        if bool(int(os.getenv("NVTE_USE_CCACHE", "0"))):
            ccache_bin = os.getenv("NVTE_CCACHE_BIN", "ccache")
            configure_command += [
                f"-DCMAKE_CXX_COMPILER_LAUNCHER={ccache_bin}",
                f"-DCMAKE_CUDA_COMPILER_LAUNCHER={ccache_bin}",
            ]
        configure_command += self.cmake_flags

        # Imported lazily so that importing this module does not require
        # pybind11 to be installed.
        import pybind11

        pybind11_dir = Path(pybind11.__file__).resolve().parent
        pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11"
        configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")

        # CMake build and install commands
        build_command = [_cmake_bin, "--build", build_dir, "--verbose"]
        install_command = [_cmake_bin, "--install", build_dir, "--verbose"]

        # Check whether parallel build is restricted
        max_jobs = get_max_jobs_for_parallel_build()
        if found_ninja():
            configure_command.append("-GNinja")
        # A bare "--parallel" lets the build tool pick the job count; a
        # positive max_jobs is appended as the explicit count.
        build_command.append("--parallel")
        if max_jobs > 0:
            build_command.append(str(max_jobs))

        # Run CMake commands
        start_time = time.perf_counter()
        for command in [configure_command, build_command, install_command]:
            print(f"Running command {' '.join(command)}")
            try:
                subprocess.run(command, cwd=build_dir, check=True)
            except (CalledProcessError, OSError) as e:
                # Chain explicitly so the underlying failure is reported as
                # the direct cause of the RuntimeError.
                raise RuntimeError(f"Error when running CMake: {e}") from e

        total_time = time.perf_counter() - start_time
        print(f"Time for build_ext: {total_time:.2f} seconds")

104

105
106
107
def get_build_ext(
    extension_cls: Type[setuptools.Extension], framework_extension_only: bool = False
):
    """Create a ``build_ext`` command class with support for CMake extensions.

    Args:
        extension_cls: Base command class that the returned class derives
            from.
        framework_extension_only: When true, built shared objects are
            relocated into the ``wheel_lib`` subdirectory, and ``.cu``,
            ``.cuh`` and ``.hip`` sources are not redirected to the GPU
            compiler.

    Returns:
        A subclass of ``extension_cls`` that builds ``CMakeExtension``
        instances with CMake and defers every other extension to the base
        class.
    """
    # Source suffixes that are compiled with the GPU compiler.
    gpu_suffixes = (".cu", ".cuh", ".hip")

    class _CMakeBuildExtension(extension_cls):
        """Setuptools command with support for CMake extension modules"""

        def run(self) -> None:
            """Build CMake extensions, then the rest, then relocate ``.so`` files."""
            # Parent directory of the last visited extension's output file.
            # Stays None when there are no extensions, in which case there is
            # nothing to relocate below.
            install_dir = None

            # Build CMake extensions
            for ext in self.extensions:
                package_path = Path(self.get_ext_fullpath(ext.name))
                install_dir = package_path.resolve().parent
                if not isinstance(ext, CMakeExtension):
                    continue
                print(f"Building CMake extension {ext.name}")

                # Set up incremental builds for CMake extensions
                build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR")
                if build_dir:
                    build_dir = Path(build_dir).resolve()
                else:
                    root_dir = Path(__file__).resolve().parent.parent
                    build_dir = root_dir / "build" / "cmake"

                # Ensure the directory exists
                build_dir.mkdir(parents=True, exist_ok=True)

                ext._build_cmake(
                    build_dir=build_dir,
                    install_dir=install_dir,
                )

            # Build non-CMake extensions as usual. The full list is restored
            # even if the base class raises.
            all_extensions = self.extensions
            self.extensions = [
                ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
            ]
            try:
                super().run()
            finally:
                self.extensions = all_extensions

            # Ensure that shared objects files for source and PyPI installations live
            # in separate directories to avoid conflicts during install and runtime.
            lib_dir = (
                "wheel_lib"
                if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or framework_extension_only
                else ""
            )

            # Ensure that binaries are not in global package space.
            # For editable/inplace builds this is not a concern as
            # the SOs will be in a local directory anyway.
            if self.inplace or install_dir is None:
                return

            if not bool(int(os.getenv("TEFL_HYGON_BACKEND", "0"))):
                target_dir = install_dir / "transformer_engine" / lib_dir
            else:
                target_dir = install_dir / "transformer_engine_hygon" / lib_dir
            target_dir.mkdir(exist_ok=True, parents=True)

            # Move (copy, then delete) top-level shared objects out of build_lib.
            for shared_object in Path(self.build_lib).glob("*.so"):
                self.copy_file(shared_object, target_dir)
                os.remove(shared_object)

        def build_extensions(self):
            """Build extensions, routing GPU sources to the GPU compiler if needed."""
            # For core lib + JAX install, fix build_ext from pybind11.setup_helpers
            # to handle CUDA files correctly.
            if "pytorch" not in get_frameworks():
                # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
                # extra_compile_args is a dict.
                for ext in self.extensions:
                    if isinstance(ext.extra_compile_args, dict):
                        for target in ("cxx", "nvcc"):
                            ext.extra_compile_args.setdefault(target, [])

                # Define new _compile method that redirects to NVCC for .cu and .cuh files.
                # Also redirect .hip files to HIPCC
                original_compile_fn = self.compiler._compile
                if not framework_extension_only:
                    # Assign a fresh list instead of extending in place: the
                    # existing list may be shared (distutils compilers define
                    # it at class level), and in-place extension would leak
                    # into other compilers and pile up duplicates.
                    src_extensions = list(self.compiler.src_extensions)
                    src_extensions += [
                        suffix for suffix in gpu_suffixes if suffix not in src_extensions
                    ]
                    self.compiler.src_extensions = src_extensions

                def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
                    """Compile one source file, swapping in the GPU compiler for GPU sources."""
                    # Copy before we make any modifications.
                    cflags = copy.deepcopy(extra_postargs)
                    original_compiler = self.compiler.compiler_so
                    try:
                        if (
                            os.path.splitext(src)[1] in gpu_suffixes
                            and not framework_extension_only
                        ):
                            if rocm_build():
                                # Second element of rocm_path() is used as the
                                # compiler executable.
                                _, nvcc_bin = rocm_path()
                            else:
                                nvcc_bin = nvcc_path()
                            self.compiler.set_executable("compiler_so", str(nvcc_bin))
                            if isinstance(cflags, dict):
                                cflags = cflags["nvcc"]

                            # Add -fPIC if not already specified
                            if not any("-fPIC" in flag for flag in cflags):
                                if rocm_build():
                                    cflags.append("-fPIC")
                                else:
                                    cflags.extend(["--compiler-options", "'-fPIC'"])

                            if not rocm_build():
                                # Forward unknown options
                                if not any("--forward-unknown-opts" in flag for flag in cflags):
                                    cflags.append("--forward-unknown-opts")
                        elif isinstance(cflags, dict):
                            cflags = cflags["cxx"]

                        # Append -std=c++17 if not already in flags
                        if not any(flag.startswith("-std=") for flag in cflags):
                            cflags.append("-std=c++17")

                        return original_compile_fn(obj, src, ext, cc_args, cflags, pp_opts)

                    finally:
                        # Put the original compiler back in place.
                        self.compiler.set_executable("compiler_so", original_compiler)

                self.compiler._compile = _compile_fn

            super().build_extensions()

    return _CMakeBuildExtension