setup.py 4.26 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
"""Installation script."""

Przemek Tredak's avatar
Przemek Tredak committed
7
import os
Tim Moon's avatar
Tim Moon committed
8
from pathlib import Path
9
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
10

Tim Moon's avatar
Tim Moon committed
11
import setuptools
Przemek Tredak's avatar
Przemek Tredak committed
12

13
14
15
16
17
18
19
20
21
22
23
from build_tools.build_ext import CMakeExtension, get_build_ext
from build_tools.utils import (
    found_cmake,
    found_ninja,
    found_pybind11,
    remove_dups,
    userbuffers_enabled,
    get_frameworks,
    install_and_import,
)
from build_tools.te_version import te_version
24

Przemek Tredak's avatar
Przemek Tredak committed
25

26
27
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()

# Select the build_ext base class. A framework-provided implementation wins,
# checked in the order pytorch -> paddle -> jax; if none of those frameworks
# was requested, the stock setuptools command is used.
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
    from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    # pybind11 is installed on demand before its setup helpers are imported.
    install_and_import("pybind11")
    from pybind11.setup_helpers import build_ext as BuildExtension
else:
    from setuptools.command.build_ext import build_ext as BuildExtension

# Wrap the chosen base class so CMake extensions can be built with it.
CMakeBuildExtension = get_build_ext(BuildExtension)
41

Tim Moon's avatar
Tim Moon committed
42

43
44
def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library

    The CMake project is the ``transformer_engine`` directory located next to
    this file. Userbuffers support is switched on via
    ``-DNVTE_WITH_USERBUFFERS=ON`` when ``userbuffers_enabled()`` is true;
    no other optional features are configured here.

    Returns:
        A ``CMakeExtension`` named ``transformer_engine``.
    """
    cmake_flags: List[str] = []
    if userbuffers_enabled():
        cmake_flags.append("-DNVTE_WITH_USERBUFFERS=ON")

    # Project directory root (directory containing this setup.py)
    root_path = Path(__file__).resolve().parent
    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / "transformer_engine",
        cmake_flags=cmake_flags,
    )


def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for build, runtime, and testing.

    Build tools (cmake, ninja, pybind11) are only requested when the
    corresponding ``found_*`` probe reports they are not already available.

    Returns:
        Tuple ``(setup_reqs, install_reqs, test_reqs)``; each list is passed
        through ``remove_dups`` before being returned.
    """

    # Common requirements
    setup_reqs: List[str] = []
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0; python_version<'3.8'",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest"]

    # Requirements that may be installed outside of Python
    if not found_cmake():
        setup_reqs.append("cmake>=3.18")
    if not found_ninja():
        setup_reqs.append("ninja")
    if not found_pybind11():
        setup_reqs.append("pybind11")

    # Return a real tuple so the value matches the annotated return type;
    # callers that unpack three values behave exactly as before.
    return (
        remove_dups(setup_reqs),
        remove_dups(install_reqs),
        remove_dups(test_reqs),
    )
Tim Moon's avatar
Tim Moon committed
86

87

88
if __name__ == "__main__":
    # Dependencies
    setup_requires, install_requires, test_requires = setup_requirements()

    __version__ = te_version()

    ext_modules = [setup_common_extension()]
    # Framework-specific extensions are added only when NVTE_RELEASE_BUILD is
    # unset or parses to 0; otherwise only the common extension is built.
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        te_root = current_file_path / "transformer_engine"
        if "pytorch" in frameworks:
            from build_tools.pytorch import setup_pytorch_extension

            ext_modules.append(
                setup_pytorch_extension(
                    "transformer_engine/pytorch/csrc",
                    te_root / "pytorch" / "csrc",
                    te_root,
                )
            )
        if "jax" in frameworks:
            from build_tools.jax import setup_jax_extension

            ext_modules.append(
                setup_jax_extension(
                    "transformer_engine/jax/csrc",
                    te_root / "jax" / "csrc",
                    te_root,
                )
            )
        if "paddle" in frameworks:
            from build_tools.paddle import setup_paddle_extension

            ext_modules.append(
                setup_paddle_extension(
                    "transformer_engine/paddle/csrc",
                    te_root / "paddle" / "csrc",
                    te_root,
                )
            )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        # Include patterns are matched against dotted package names, so a
        # slash-separated pattern can never match; "transformer_engine.*"
        # already covers every subpackage.
        packages=setuptools.find_packages(
            include=["transformer_engine", "transformer_engine.*"],
        ),
        extras_require={
            "test": test_requires,
        },
        description="Transformer acceleration library",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension},
        setup_requires=setup_requires,
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=True,
        package_data={"": ["VERSION.txt"]},
    )