# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""PyTorch related extensions."""
import os
from pathlib import Path

import setuptools

from .utils import (
    all_files_in_dir,
    cuda_archs,
    cuda_path,
    cuda_version,
)


def setup_pytorch_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Setup CUDA extension for PyTorch support.

    Collects the extension's sources, include directories, compiler flags and
    (optionally) MPI link settings, then builds a
    ``torch.utils.cpp_extension.CUDAExtension`` named
    ``transformer_engine_torch``.

    Args:
        csrc_source_files: Root directory of the PyTorch extension sources
            (``str`` or ``Path``). Everything returned by
            ``all_files_in_dir`` for its ``extensions`` subdirectory is
            compiled, in addition to a fixed set of top-level and
            ``userbuffers`` files.
        csrc_header_files: Directory added to the include path.
        common_header_files: Root of the shared headers (``str`` or
            ``Path``); it and its ``common`` and ``common/include``
            subdirectories are added to the include path.

    Environment variables read:
        NVTE_BUILD_THREADS_PER_JOB: Value passed to ``nvcc --threads``
            (default ``"1"``); only applied when the CUDA version is detected.
        NVTE_UB_WITH_MPI: When set to anything other than empty, ``0``,
            ``false``, ``no`` or ``off`` (case-insensitive), userbuffers are
            built with MPI support.
        MPI_HOME: MPI installation prefix; required when NVTE_UB_WITH_MPI is
            enabled.

    Returns:
        The configured ``CUDAExtension``.

    Raises:
        RuntimeError: If the detected CUDA Toolkit is older than 12.0, or if
            MPI support is requested but ``MPI_HOME`` is unset or empty.
    """

    # Source files
    csrc_source_files = Path(csrc_source_files)
    extensions_dir = csrc_source_files / "extensions"
    sources = [
        csrc_source_files / "common.cu",
        csrc_source_files / "ts_fp8_op.cpp",
        csrc_source_files / "userbuffers" / "ipcsocket.cc",
        csrc_source_files / "userbuffers" / "userbuffers.cu",
        csrc_source_files / "userbuffers" / "userbuffers-host.cpp",
    ] + all_files_in_dir(extensions_dir)

    # Header files. Wrap in Path so plain-string arguments also support "/".
    common_header_files = Path(common_header_files)
    include_dirs = [
        common_header_files,
        common_header_files / "common",
        common_header_files / "common" / "include",
        csrc_header_files,
    ]

    # Compiler flags
    cxx_flags = [
        "-O3",
        "-fvisibility=hidden",
    ]
    nvcc_flags = [
        "-O3",
        # Undefine the __CUDA_NO_* guard macros so the half/bfloat16 operator
        # and conversion overloads from the CUDA headers are available
        # (PyTorch's default nvcc flags define these guards).
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    # NOTE(review): these are `in` membership tests on whatever cuda_archs()
    # returns; if it is a delimited string (e.g. "70;80;90") they are
    # substring matches — confirm the format in .utils.
    cuda_architectures = cuda_archs()

    if "70" in cuda_architectures:
        nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])

    # Version-dependent CUDA options
    try:
        version = cuda_version()
    except FileNotFoundError:
        print("Could not determine CUDA Toolkit version")
    else:
        if version < (12, 0):
            raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
        nvcc_flags.extend(
            (
                "--threads",
                os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
            )
        )

        # NOTE(review): sm_80/sm_90 code is only generated when the toolkit
        # version was detected; if detection fails these targets are skipped.
        if "80" in cuda_architectures:
            nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
        if "90" in cuda_architectures:
            nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])

    # Libraries
    library_dirs = []
    libraries = []
    # Parse the flag explicitly so that e.g. NVTE_UB_WITH_MPI=0 disables MPI
    # (a bare truthiness test on the string would treat "0" as enabled).
    ub_with_mpi = os.getenv("NVTE_UB_WITH_MPI", "").strip().lower()
    if ub_with_mpi not in ("", "0", "false", "no", "off"):
        mpi_home_env = os.getenv("MPI_HOME")
        # Explicit raise (not assert) so the check survives `python -O`;
        # an empty MPI_HOME is rejected too, since Path("") would mean ".".
        if not mpi_home_env:
            raise RuntimeError("MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1")
        mpi_home = Path(mpi_home_env)
        include_dirs.append(mpi_home / "include")
        cxx_flags.append("-DNVTE_UB_WITH_MPI")
        nvcc_flags.append("-DNVTE_UB_WITH_MPI")
        library_dirs.append(mpi_home / "lib")
        libraries.append("mpi")

    # Construct PyTorch CUDA extension. torch is imported here rather than at
    # module level, presumably so this build module can be imported without
    # torch installed.
    from torch.utils.cpp_extension import CUDAExtension

    return CUDAExtension(
        name="transformer_engine_torch",
        sources=[str(src) for src in sources],
        include_dirs=[str(inc) for inc in include_dirs],
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
        libraries=[str(lib) for lib in libraries],
        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
    )