# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""PyTorch related extensions."""
import os
from pathlib import Path

import setuptools

from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled, rocm_build, hipify

from typing import List


def install_requirements() -> List[str]:
    """Return the runtime pip requirements for TE/PyTorch extensions.

    Returns:
        List of pip requirement specifier strings.
    """
    # Fix: "torch>=2.1" was previously appended twice (once in the initial
    # list and again via reqs.extend), yielding a duplicate requirement.
    reqs = ["torch>=2.1", "einops"]
    # Optional dependencies kept for reference; re-enable when needed:
    # reqs.append(
    #     "nvdlfw-inspect @"
    #     " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
    # )
    # reqs.extend(
    #     [
    #         "onnx",
    #         "onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871",
    #     ]
    # )
    return reqs


def test_requirements() -> List[str]:
    """Return the test-only pip requirements for TE/PyTorch extensions."""
    # Fix: docstring previously said "TE/JAX" — this module is PyTorch-specific.
    return ["numpy", "torchvision", "transformers"]


def setup_pytorch_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Set up the CUDA/HIP extension for PyTorch support.

    Args:
        csrc_source_files: Directory containing the extension's .cpp sources.
        csrc_header_files: Directory containing the extension's headers.
        common_header_files: Root directory of the shared TE headers.

    Returns:
        A ``CUDAExtension`` (ROCm build) or ``CppExtension`` (CUDA build)
        named ``transformer_engine_torch``.

    Raises:
        RuntimeError: On CUDA builds when the detected CUDA version is < 12.0.
        AssertionError: When ``NVTE_UB_WITH_MPI=1`` without ``MPI_HOME`` set,
            or ``NVTE_ENABLE_NVSHMEM=1`` without ``NVSHMEM_HOME`` set.
    """
    # Source files
    sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")

    # Header files. ROCm builds use a fixed include list; CUDA builds start
    # from the toolkit's include directories.
    if rocm_build():
        include_dirs = [
            common_header_files,
            common_header_files / "common",
            common_header_files / "common" / "include",
            csrc_header_files,
        ]
    else:
        include_dirs = get_cuda_include_dirs()
        include_dirs.extend(
            [
                common_header_files,
                common_header_files / "common",
                common_header_files / "common" / "include",
                csrc_header_files,
            ]
        )

    # ROCm builds translate CUDA sources to HIP before compiling.
    if rocm_build():
        current_file_path = Path(__file__).parent.resolve()
        base_dir = current_file_path.parent
        sources = hipify(base_dir, csrc_source_files, sources, include_dirs)

    # Compiler flags
    cxx_flags = [
        "-O3",
        "-fvisibility=hidden",
    ]
    # Fix: nvcc_flags was previously only defined inside the ROCm branch
    # (with a dead `else: pass`); initialize it unconditionally so later
    # rocm_build()-guarded appends can never hit an unbound name.
    nvcc_flags = []
    if rocm_build():
        nvcc_flags += [
            "-O3",
            "-U__HIP_NO_HALF_OPERATORS__",
            "-U__HIP_NO_HALF_CONVERSIONS__",
            "-U__HIP_NO_BFLOAT16_OPERATORS__",
            "-U__HIP_NO_BFLOAT16_CONVERSIONS__",
            "-U__HIP_NO_BFLOAT162_OPERATORS__",
            "-U__HIP_NO_BFLOAT162_CONVERSIONS__",
            "-DUSE_ROCM",
        ]

        # Warning suppression is opt-out (defaults to enabled) via env vars.
        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"))):
            unused_warning_flags = [
                "-Wno-unused-result",
                "-Wno-unused-function",
                "-Wno-unused-private-field",
                "-Wno-unused-variable",
            ]
            nvcc_flags.extend(unused_warning_flags)
            cxx_flags.extend(unused_warning_flags)

        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "1"))):
            nvcc_flags.append("-Wno-return-type")
            cxx_flags.append("-Wno-return-type")

        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "1"))):
            nvcc_flags.append("-Wno-sign-compare")
            cxx_flags.append("-Wno-sign-compare")

    # Debug builds keep symbols and assertions in both compilers.
    if debug_build_enabled():
        cxx_flags.append("-g")
        cxx_flags.append("-UNDEBUG")
        if rocm_build():
            nvcc_flags.append("-g")
            nvcc_flags.append("-UNDEBUG")

    # Version-dependent compiler options
    if rocm_build():
        ##TODO: Figure out which hipcc version starts to support this parallel compilation
        nvcc_flags.extend(["-parallel-jobs=4"])
    else:
        try:
            version = cuda_version()
        except FileNotFoundError:
            # nvcc not found on PATH: warn and let the build surface any
            # incompatibility later instead of failing here.
            print("Could not determine CUDA version")
        else:
            if version < (12, 0):
                raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")

    # Libraries
    library_dirs = []
    libraries = []
    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
        assert (
            os.getenv("MPI_HOME") is not None
        ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
        mpi_path = Path(os.getenv("MPI_HOME"))
        include_dirs.append(mpi_path / "include")
        cxx_flags.append("-DNVTE_UB_WITH_MPI")
        # NOTE(review): the MPI library/lib-dir entries are only added on ROCm
        # builds; confirm whether the CUDA path links MPI elsewhere.
        if rocm_build():
            nvcc_flags.append("-DNVTE_UB_WITH_MPI")
            library_dirs.append(mpi_path / "lib")
            libraries.append("mpi")

    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        nvshmem_home = Path(os.getenv("NVSHMEM_HOME"))
        include_dirs.append(nvshmem_home / "include")
        library_dirs.append(nvshmem_home / "lib")
        libraries.append("nvshmem_host")
        cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
        if rocm_build():
            nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")

    # Construct the PyTorch extension. Convert every path to str exactly once
    # (previously sources/include_dirs were str()-converted twice).
    sources = [str(path) for path in sources]
    include_dirs = [str(path) for path in include_dirs]
    library_dirs = [str(lib_dir) for lib_dir in library_dirs]
    libraries = [str(lib) for lib in libraries]

    # Imported lazily so this module can be inspected without torch installed.
    from torch.utils.cpp_extension import CppExtension, CUDAExtension

    if rocm_build():
        return CUDAExtension(
            name="transformer_engine_torch",
            sources=sources,
            include_dirs=include_dirs,
            extra_compile_args={
                "cxx": cxx_flags,
                "nvcc": nvcc_flags,
            },
            libraries=libraries,
            library_dirs=library_dirs,
        )
    return CppExtension(
        name="transformer_engine_torch",
        sources=sources,
        include_dirs=include_dirs,
        extra_compile_args={"cxx": cxx_flags},
        libraries=libraries,
        library_dirs=library_dirs,
    )