# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""PyTorch related extensions."""
import os
from pathlib import Path

import setuptools

from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled

from typing import List


def install_requirements() -> List[str]:
    """Install dependencies for TE/PyTorch extensions."""
    # Runtime requirements for the PyTorch integration, pinned only where
    # a minimum version is actually needed (torch>=2.1).
    requirements = (
        "torch>=2.1",
        "einops",
        "onnxscript",
        "onnx",
        "packaging",
        "pydantic",
        "nvdlfw-inspect",
        "triton",
    )
    return list(requirements)


def test_requirements() -> List[str]:
    """Test dependencies for TE/PyTorch extensions."""
    # Packages needed only when running the PyTorch test suite; torchao is
    # pinned to the release the tests were validated against.
    deps = (
        "numpy",
        "torchvision",
        "transformers",
        "torchao==0.13",
        "onnxruntime",
        "onnxruntime_extensions",
    )
    return list(deps)


def setup_pytorch_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Setup CUDA extension for PyTorch support

    Args:
        csrc_source_files: directory containing the extension's .cpp sources.
        csrc_header_files: directory containing the extension's own headers.
        common_header_files: root directory of the shared TE headers
            (expects ``common`` and ``common/include`` subdirectories).

    Returns:
        A ``torch.utils.cpp_extension.CppExtension`` building the
        ``transformer_engine_torch`` module.

    Raises:
        RuntimeError: if the detected CUDA version is older than 12.0.
        AssertionError: if MPI/NVSHMEM support is requested via environment
            variables but MPI_HOME/NVSHMEM_HOME is not set.
    """
    # Source files
    sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")

    # Header files
    include_dirs = get_cuda_include_dirs()
    include_dirs.extend(
        [
            common_header_files,
            common_header_files / "common",
            common_header_files / "common" / "include",
            csrc_header_files,
        ]
    )

    # Compiler flags
    cxx_flags = ["-O3", "-fvisibility=hidden"]
    if debug_build_enabled():
        cxx_flags.append("-g")
        cxx_flags.append("-UNDEBUG")
    else:
        cxx_flags.append("-g0")

    # Version-dependent CUDA options. Best-effort: if nvcc cannot be found,
    # warn and proceed rather than failing the build outright.
    try:
        version = cuda_version()
    except FileNotFoundError:
        print("Could not determine CUDA version")
    else:
        if version < (12, 0):
            raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")

    # Optional Userbuffers-with-MPI support.
    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
        assert (
            os.getenv("MPI_HOME") is not None
        ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
        mpi_path = Path(os.getenv("MPI_HOME"))
        include_dirs.append(mpi_path / "include")
        cxx_flags.append("-DNVTE_UB_WITH_MPI")

    # Optional NVSHMEM support.
    # Fix: use string default "0" for consistency with NVTE_UB_WITH_MPI above.
    library_dirs = []
    libraries = []
    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        nvshmem_home = Path(os.getenv("NVSHMEM_HOME"))
        include_dirs.append(nvshmem_home / "include")
        library_dirs.append(nvshmem_home / "lib")
        libraries.append("nvshmem_host")
        cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")

    # Construct PyTorch CUDA extension. Convert all Path objects to str
    # exactly once (previously sources/include_dirs were converted twice).
    # Import deferred so this module can be imported without torch installed.
    from torch.utils.cpp_extension import CppExtension

    return CppExtension(
        name="transformer_engine_torch",
        sources=[str(src) for src in sources],
        include_dirs=[str(inc) for inc in include_dirs],
        extra_compile_args={"cxx": cxx_flags},
        libraries=[str(lib) for lib in libraries],
        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
    )