# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Installation script."""

import os
import subprocess
import sys
import sysconfig
import copy
import time

from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Type

import setuptools

from .utils import (
    cmake_bin,
    debug_build_enabled,
    found_ninja,
    get_frameworks,
    nvcc_path,
    get_max_jobs_for_parallel_build,
)


class CMakeExtension(setuptools.Extension):
    """CMake extension module.

    A setuptools extension that carries no sources of its own; the build
    is delegated to CMake through :meth:`_build_cmake`.
    """

    def __init__(
        self,
        name: str,
        cmake_path: Path,
        cmake_flags: Optional[List[str]] = None,
    ) -> None:
        """Store the CMake source location and extra configure flags.

        Args:
            name: Extension name, forwarded to ``setuptools.Extension``.
            cmake_path: Directory passed to CMake as the source tree (``-S``).
            cmake_flags: Extra arguments appended to the CMake configure
                command. ``None`` means no extra flags.
        """
        super().__init__(name, sources=[])  # No work for base class
        self.cmake_path: Path = cmake_path
        self.cmake_flags: List[str] = [] if cmake_flags is None else cmake_flags

    def _build_cmake(self, build_dir: Path, install_dir: Path) -> None:
        """Configure, build and install this extension with CMake.

        Runs three CMake invocations in sequence (configure, ``--build``,
        ``--install``) with ``build_dir`` as the working directory, then
        prints the total elapsed time.

        Args:
            build_dir: CMake binary directory (``-B``); also used as the
                working directory for every command.
            install_dir: Value passed as ``CMAKE_INSTALL_PREFIX``.

        Raises:
            RuntimeError: If any CMake command fails or cannot be started.
                The original exception is attached as the cause.
        """
        # subprocess and the f-strings below need plain strings
        cmake_exe = str(cmake_bin())
        source_dir = str(self.cmake_path)
        binary_dir = str(build_dir)
        prefix_dir = str(install_dir)

        # CMake configure command
        build_type = "Debug" if debug_build_enabled() else "Release"
        configure_command = [
            cmake_exe,
            "-S",
            source_dir,
            "-B",
            binary_dir,
            f"-DPython_EXECUTABLE={sys.executable}",
            f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}",
            f"-DCMAKE_BUILD_TYPE={build_type}",
            f"-DCMAKE_INSTALL_PREFIX={prefix_dir}",
        ]
        configure_command += self.cmake_flags

        # Point CMake at the pybind11 config shipped inside the Python package
        import pybind11

        pybind11_dir = Path(pybind11.__file__).resolve().parent
        pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11"
        configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")

        # CMake build and install commands
        build_command = [cmake_exe, "--build", binary_dir, "--verbose"]
        install_command = [cmake_exe, "--install", binary_dir, "--verbose"]

        # Check whether parallel build is restricted
        max_jobs = get_max_jobs_for_parallel_build()
        if found_ninja():
            configure_command.append("-GNinja")
        # "--parallel" is always passed; an explicit job count follows only
        # when a positive limit was requested.
        build_command.append("--parallel")
        if max_jobs > 0:
            build_command.append(str(max_jobs))

        # Run CMake commands
        start_time = time.perf_counter()
        for command in [configure_command, build_command, install_command]:
            print(f"Running command {' '.join(command)}")
            try:
                subprocess.run(command, cwd=binary_dir, check=True)
            except (CalledProcessError, OSError) as e:
                # Chain explicitly so the underlying failure stays visible
                raise RuntimeError(f"Error when running CMake: {e}") from e

        total_time = time.perf_counter() - start_time
        print(f"Time for build_ext: {total_time:.2f} seconds")

def get_build_ext(
    extension_cls: Type[setuptools.Extension], framework_extension_only: bool = False
):
    """Create a build command class that also handles CMake extensions.

    Args:
        extension_cls: Base class for the returned command. It is used as a
            command class here (its ``run``, ``build_extensions`` and
            ``get_ext_fullpath`` are called) despite the annotation.
            NOTE(review): presumably a ``build_ext`` command class; confirm
            against callers.
        framework_extension_only: When true, built binaries are placed under
            ``transformer_engine/wheel_lib`` and ``.cu``/``.cuh`` sources are
            not redirected to NVCC.

    Returns:
        A subclass of ``extension_cls`` that builds ``CMakeExtension``
        instances with CMake and everything else through the base class.
    """

    class _CMakeBuildExtension(extension_cls):
        """Setuptools command with support for CMake extension modules"""

        def run(self) -> None:
            """Build all extensions and move the resulting ``.so`` files.

            CMake extensions are built first, the remaining extensions are
            handed to the base class, and finally every ``*.so`` found in
            ``self.build_lib`` is copied under ``transformer_engine`` next to
            the last extension's output path and removed from ``build_lib``.
            """
            # Stays None when there are no extensions at all, so the
            # relocation step below can be skipped instead of failing on an
            # unbound name.
            install_dir = None

            # Build CMake extensions
            for ext in self.extensions:
                package_path = Path(self.get_ext_fullpath(ext.name))
                install_dir = package_path.resolve().parent
                if not isinstance(ext, CMakeExtension):
                    continue

                print(f"Building CMake extension {ext.name}")
                # Set up incremental builds for CMake extensions: reuse a
                # fixed build directory rather than a temporary one.
                build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR")
                if build_dir:
                    build_dir = Path(build_dir).resolve()
                else:
                    root_dir = Path(__file__).resolve().parent.parent
                    build_dir = root_dir / "build" / "cmake"

                # Ensure the directory exists
                build_dir.mkdir(parents=True, exist_ok=True)

                ext._build_cmake(
                    build_dir=build_dir,
                    install_dir=install_dir,
                )

            # Build non-CMake extensions as usual. The full list is restored
            # even if the base class raises.
            all_extensions = self.extensions
            self.extensions = [
                ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
            ]
            try:
                super().run()
            finally:
                self.extensions = all_extensions

            if install_dir is None:
                # No extensions were processed; nothing to relocate.
                return

            # Ensure that binaries are not in global package space.
            lib_dir = (
                "wheel_lib"
                if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or framework_extension_only
                else ""
            )
            target_dir = install_dir / "transformer_engine" / lib_dir
            target_dir.mkdir(exist_ok=True, parents=True)

            for lib_file in Path(self.build_lib).glob("*.so"):
                self.copy_file(lib_file, target_dir)
                os.remove(lib_file)

        def build_extensions(self):
            """Build extensions, routing CUDA sources to NVCC when needed.

            When "pytorch" is not among the frameworks reported by
            ``get_frameworks()``, the compiler's ``_compile`` hook is wrapped
            so that ``.cu``/``.cuh`` files are compiled with NVCC (unless
            ``framework_extension_only`` is set) and a C++ standard flag is
            always present.
            """
            # For core lib + JAX install, fix build_ext from pybind11.setup_helpers
            # to handle CUDA files correctly.
            if "pytorch" not in get_frameworks():
                # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
                # extra_compile_args is a dict.
                for ext in self.extensions:
                    if isinstance(ext.extra_compile_args, dict):
                        for target in ("cxx", "nvcc"):
                            ext.extra_compile_args.setdefault(target, [])

                # Define new _compile method that redirects to NVCC for .cu and .cuh files.
                original_compile_fn = self.compiler._compile
                if not framework_extension_only:
                    # Rebind a fresh list on this compiler instance; an
                    # in-place "+=" would also extend the list object the
                    # attribute currently refers to, which may be shared.
                    self.compiler.src_extensions = [
                        *self.compiler.src_extensions,
                        ".cu",
                        ".cuh",
                    ]

                def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
                    # Copy before we make any modifications.
                    cflags = copy.deepcopy(extra_postargs)
                    original_compiler = self.compiler.compiler_so
                    try:
                        is_cuda_source = os.path.splitext(src)[1] in [".cu", ".cuh"]
                        if is_cuda_source and not framework_extension_only:
                            nvcc_bin = nvcc_path()
                            self.compiler.set_executable("compiler_so", str(nvcc_bin))
                            if isinstance(cflags, dict):
                                cflags = cflags["nvcc"]

                            # Add -fPIC if not already specified
                            if not any("-fPIC" in flag for flag in cflags):
                                cflags.extend(["--compiler-options", "'-fPIC'"])

                            # Forward unknown options
                            if not any("--forward-unknown-opts" in flag for flag in cflags):
                                cflags.append("--forward-unknown-opts")
                        elif isinstance(cflags, dict):
                            cflags = cflags["cxx"]

                        # Append -std=c++17 if not already in flags
                        if not any(flag.startswith("-std=") for flag in cflags):
                            cflags.append("-std=c++17")

                        return original_compile_fn(obj, src, ext, cc_args, cflags, pp_opts)

                    finally:
                        # Put the original compiler back in place.
                        self.compiler.set_executable("compiler_so", original_compiler)

                self.compiler._compile = _compile_fn

            super().build_extensions()

    return _CMakeBuildExtension