# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Installation script."""

import ctypes
import os
import subprocess
import sys
import sysconfig
import copy
import time

from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Type

import setuptools

from .utils import (
    rocm_build,
    rocm_path,
    cmake_bin,
    debug_build_enabled,
    found_ninja,
    get_frameworks,
    cuda_path,
    get_max_jobs_for_parallel_build,
)


class CMakeExtension(setuptools.Extension):
    """Extension module that is configured, built and installed with CMake.

    The setuptools base class is initialised with an empty source list; all
    of the actual work happens in :meth:`_build_cmake`.
    """

    def __init__(
        self,
        name: str,
        cmake_path: Path,
        cmake_flags: Optional[List[str]] = None,
    ) -> None:
        """Store the CMake source location and any extra configure flags.

        Args:
            name: Extension name forwarded to ``setuptools.Extension``.
            cmake_path: Directory passed to CMake as the source tree (``-S``).
            cmake_flags: Extra arguments appended to the configure command.
                ``None`` is treated as an empty list.
        """
        super().__init__(name, sources=[])  # No work for base class
        self.cmake_path: Path = cmake_path
        self.cmake_flags: List[str] = [] if cmake_flags is None else cmake_flags

    def _build_cmake(self, build_dir: Path, install_dir: Path) -> None:
        """Run the CMake configure, build and install steps in sequence.

        Args:
            build_dir: CMake binary directory (``-B``); also used as the
                working directory of every subprocess.
            install_dir: Value passed as ``CMAKE_INSTALL_PREFIX``.

        Raises:
            RuntimeError: If any CMake invocation exits with a non-zero status
                or cannot be started. The underlying ``CalledProcessError`` /
                ``OSError`` is attached as ``__cause__``.
        """
        # Make sure paths are str
        _cmake_bin = str(cmake_bin())
        cmake_path = str(self.cmake_path)
        build_dir = str(build_dir)
        install_dir = str(install_dir)

        # CMake configure command
        build_type = "Debug" if debug_build_enabled() else "Release"
        configure_command = [
            _cmake_bin,
            "-S",
            cmake_path,
            "-B",
            build_dir,
            f"-DPython_EXECUTABLE={sys.executable}",
            f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}",
            f"-DCMAKE_BUILD_TYPE={build_type}",
            f"-DCMAKE_INSTALL_PREFIX={install_dir}",
        ]
        configure_command += self.cmake_flags

        # Imported lazily so that merely importing this module does not
        # require pybind11 to be installed.
        import pybind11

        # Point CMake at the package config directory inside the installed
        # pybind11 Python package.
        pybind11_dir = Path(pybind11.__file__).resolve().parent
        pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11"
        configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")

        # CMake build and install commands
        build_command = [_cmake_bin, "--build", build_dir, "--verbose"]
        install_command = [_cmake_bin, "--install", build_dir, "--verbose"]

        # Check whether parallel build is restricted
        max_jobs = get_max_jobs_for_parallel_build()
        if found_ninja():
            configure_command.append("-GNinja")
        # A parallel build is always requested; an explicit job count is only
        # passed when max_jobs is positive.
        build_command.append("--parallel")
        if max_jobs > 0:
            build_command.append(str(max_jobs))

        # Run CMake commands
        start_time = time.perf_counter()
        for command in [configure_command, build_command, install_command]:
            print(f"Running command {' '.join(command)}")
            try:
                subprocess.run(command, cwd=build_dir, check=True)
            except (CalledProcessError, OSError) as e:
                # Chain explicitly so the original failure stays attached as
                # the cause of the RuntimeError.
                raise RuntimeError(f"Error when running CMake: {e}") from e

        # Only reached when all three steps succeeded.
        total_time = time.perf_counter() - start_time
        print(f"Time for build_ext: {total_time:.2f} seconds")

98

99
def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel_lib: bool = False):
    """Create a build command class that also handles CMake and CUDA/HIP sources.

    Args:
        extension_cls: Base class of the returned command. It is used as a
            build command here: ``run``, ``build_extensions``,
            ``get_ext_fullpath``, ``copy_file``, ``build_lib``, ``extensions``
            and ``compiler`` are all accessed on it.
        install_so_in_wheel_lib: When true, built ``.so`` files are moved into
            a ``wheel_lib`` subdirectory of ``transformer_engine``. The same
            happens when the ``NVTE_RELEASE_BUILD`` environment variable
            parses to a non-zero integer.

    Returns:
        The ``_CMakeBuildExtension`` class derived from ``extension_cls``.
    """

    class _CMakeBuildExtension(extension_cls):
        """Setuptools command with support for CMake extension modules"""

        def run(self) -> None:
            """Build CMake extensions, then the rest, then relocate ``.so`` files."""
            # Stays None when there are no extensions at all; in that case
            # nothing was built and there is nothing to relocate.
            install_dir = None

            # Build CMake extensions
            for ext in self.extensions:
                package_path = Path(self.get_ext_fullpath(ext.name))
                # Updated for every extension, so the relocation step below
                # uses the directory of the last extension in the list.
                install_dir = package_path.resolve().parent
                if not isinstance(ext, CMakeExtension):
                    continue

                print(f"Building CMake extension {ext.name}")
                # Set up incremental builds for CMake extensions
                build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR")
                if build_dir:
                    build_dir = Path(build_dir).resolve()
                else:
                    root_dir = Path(__file__).resolve().parent.parent
                    build_dir = root_dir / "build" / "cmake"

                # Ensure the directory exists
                build_dir.mkdir(parents=True, exist_ok=True)

                ext._build_cmake(
                    build_dir=build_dir,
                    install_dir=install_dir,
                )

            # Build non-CMake extensions as usual. The base command only sees
            # the filtered list; the full list is restored afterwards.
            all_extensions = self.extensions
            self.extensions = [
                ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
            ]
            super().run()
            self.extensions = all_extensions

            if install_dir is None:
                return

            # Ensure that binaries are not in global package space.
            lib_dir = (
                "wheel_lib"
                if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or install_so_in_wheel_lib
                else ""
            )
            target_dir = install_dir / "transformer_engine" / lib_dir
            target_dir.mkdir(exist_ok=True, parents=True)

            # Move (copy, then delete) every top-level shared object.
            for so_file in Path(self.build_lib).glob("*.so"):
                self.copy_file(so_file, target_dir)
                os.remove(so_file)

        def build_extensions(self):
            """Patch the compiler to route ``.cu``/``.cuh``/``.hip`` files, then build."""
            # BuildExtensions from PyTorch already handle CUDA files correctly
            # so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed.
            if "pytorch" not in get_frameworks():
                # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
                # extra_compile_args is a dict.
                for ext in self.extensions:
                    if isinstance(ext.extra_compile_args, dict):
                        for target in ("cxx", "nvcc"):
                            ext.extra_compile_args.setdefault(target, [])

                # Define new _compile method that redirects to NVCC for .cu and .cuh files.
                # Also redirect .hip files to HIPCC
                original_compile_fn = self.compiler._compile
                # Rebind to a new list instead of extending in place, so a list
                # shared with other compiler objects is left untouched.
                self.compiler.src_extensions = list(self.compiler.src_extensions) + [
                    ".cu",
                    ".cuh",
                    ".hip",
                ]

                def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
                    # Copy before we make any modifications.
                    cflags = copy.deepcopy(extra_postargs)
                    original_compiler = self.compiler.compiler_so
                    try:
                        is_rocm = rocm_build()
                        # Second element of the returned pair is used as the
                        # device compiler executable.
                        if is_rocm:
                            _, nvcc_bin = rocm_path()
                        else:
                            _, nvcc_bin = cuda_path()

                        if os.path.splitext(src)[1] in [".cu", ".cuh", ".hip"]:
                            self.compiler.set_executable("compiler_so", str(nvcc_bin))
                            if isinstance(cflags, dict):
                                cflags = cflags["nvcc"]

                            # Add -fPIC if not already specified
                            if not any("-fPIC" in flag for flag in cflags):
                                if is_rocm:
                                    cflags.append("-fPIC")
                                else:
                                    cflags.extend(["--compiler-options", "'-fPIC'"])

                            if not is_rocm:
                                # Forward unknown options
                                if not any("--forward-unknown-opts" in flag for flag in cflags):
                                    cflags.append("--forward-unknown-opts")

                        elif isinstance(cflags, dict):
                            cflags = cflags["cxx"]

                        # Append -std=c++17 if not already in flags
                        if not any(flag.startswith("-std=") for flag in cflags):
                            cflags.append("-std=c++17")

                        return original_compile_fn(obj, src, ext, cc_args, cflags, pp_opts)

                    finally:
                        # Put the original compiler back in place.
                        self.compiler.set_executable("compiler_so", original_compiler)

                self.compiler._compile = _compile_fn

            super().build_extensions()

    return _CMakeBuildExtension