# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""PyTorch related extensions."""
import os
from pathlib import Path

import setuptools

from .utils import (
    all_files_in_dir,
    cuda_archs,
    cuda_version,
    hipify,
    rocm_build,
)


def setup_pytorch_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Set up the CUDA/HIP extension for PyTorch support.

    Args:
        csrc_source_files: Directory containing the C++/CUDA sources
            (``common.cpp`` plus everything under ``extensions/``).
        csrc_header_files: Directory containing the extension headers.
        common_header_files: Directory containing the shared framework
            headers (with ``common`` and ``common/include`` subdirectories).

    Returns:
        A ``torch.utils.cpp_extension.CUDAExtension`` describing the
        ``transformer_engine_torch`` module.

    Raises:
        RuntimeError: If a CUDA Toolkit older than 12.0 is detected
            (non-ROCm builds only).
    """

    # Source files
    csrc_source_files = Path(csrc_source_files)
    extensions_dir = csrc_source_files / "extensions"
    sources = [
        csrc_source_files / "common.cpp",
    ] + all_files_in_dir(extensions_dir)

    # Header files
    include_dirs = [
        common_header_files,
        common_header_files / "common",
        common_header_files / "common" / "include",
        csrc_header_files,
    ]

    if rocm_build():
        # Translate CUDA sources to HIP. NOTE(review): hipify receives
        # include_dirs and may adjust it in place — confirm against .utils.
        current_file_path = Path(__file__).parent.resolve()
        base_dir = current_file_path.parent
        sources = hipify(base_dir, csrc_source_files, sources, include_dirs)

    # Compiler flags
    cxx_flags = [
        "-O3",
        "-fvisibility=hidden",
    ]
    if rocm_build():
        nvcc_flags = [
            "-O3",
            "-U__HIP_NO_HALF_OPERATORS__",
            "-U__HIP_NO_HALF_CONVERSIONS__",
            "-U__HIP_NO_BFLOAT16_OPERATORS__",
            "-U__HIP_NO_BFLOAT16_CONVERSIONS__",
            "-U__HIP_NO_BFLOAT162_OPERATORS__",
            "-U__HIP_NO_BFLOAT162_CONVERSIONS__",
            "-DUSE_ROCM",
        ]

        # Optional warning suppression, controlled via environment variables.
        # NOTE(review): these knobs currently apply only to ROCm builds.
        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"))):
            # Same flag set for both compilers; build it once.
            unused_warning_flags = [
                "-Wno-unused-result",
                "-Wno-unused-function",
                "-Wno-unused-private-field",
                "-Wno-unused-variable",
            ]
            nvcc_flags.extend(unused_warning_flags)
            cxx_flags.extend(unused_warning_flags)

        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "0"))):
            nvcc_flags.append("-Wno-return-type")
            cxx_flags.append("-Wno-return-type")

        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "0"))):
            nvcc_flags.append("-Wno-sign-compare")
            cxx_flags.append("-Wno-sign-compare")

    else:
        nvcc_flags = [
            "-O3",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_BFLOAT16_OPERATORS__",
            "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
            "-U__CUDA_NO_BFLOAT162_OPERATORS__",
            "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
            "--expt-relaxed-constexpr",
            "--expt-extended-lambda",
            "--use_fast_math",
        ]

    # Version-dependent CUDA options
    if rocm_build():
        # TODO: Figure out which hipcc version starts to support this
        # parallel compilation flag.
        nvcc_flags.extend(["-parallel-jobs=4"])
    else:
        cuda_architectures = cuda_archs()
        if "70" in cuda_architectures:
            nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])
        try:
            version = cuda_version()
        except FileNotFoundError:
            # Best effort: proceed without version-gated options if the
            # toolkit cannot be located.
            print("Could not determine CUDA Toolkit version")
        else:
            if version < (12, 0):
                raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
            nvcc_flags.extend(
                (
                    "--threads",
                    os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
                )
            )

            for arch in cuda_architectures.split(";"):
                if arch == "70":
                    continue  # Already handled
                nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])

    # Libraries
    library_dirs = []
    libraries = []
    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
        # Userbuffers-with-MPI support requires a local MPI installation.
        mpi_home = os.getenv("MPI_HOME")
        assert (
            mpi_home is not None
        ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
        mpi_path = Path(mpi_home)
        include_dirs.append(mpi_path / "include")
        cxx_flags.append("-DNVTE_UB_WITH_MPI")
        nvcc_flags.append("-DNVTE_UB_WITH_MPI")
        library_dirs.append(mpi_path / "lib")
        libraries.append("mpi")

    # Construct PyTorch CUDA extension. Import is deferred so that torch is
    # only required when the extension is actually being built.
    from torch.utils.cpp_extension import CUDAExtension

    return CUDAExtension(
        name="transformer_engine_torch",
        # Convert Path objects to str exactly once (previously done twice).
        sources=[str(src) for src in sources],
        include_dirs=[str(inc) for inc in include_dirs],
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
        libraries=[str(lib) for lib in libraries],
        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
    )