setup.py 4.94 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
"""Installation script."""

Przemek Tredak's avatar
Przemek Tredak committed
7
import os
Tim Moon's avatar
Tim Moon committed
8
from pathlib import Path
9
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
10

Tim Moon's avatar
Tim Moon committed
11
import setuptools
Przemek Tredak's avatar
Przemek Tredak committed
12

13
14
15
16
17
18
19
20
from build_tools.build_ext import CMakeExtension, get_build_ext
from build_tools.utils import (
    found_cmake,
    found_ninja,
    found_pybind11,
    remove_dups,
    get_frameworks,
    install_and_import,
21
    uninstall_te_fw_packages,
22
23
)
from build_tools.te_version import te_version
24

Przemek Tredak's avatar
Przemek Tredak committed
25

26
27
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()

# Mark that a source build of the project is in progress.
# NOTE(review): the consumers of this variable are not visible here —
# presumably build_tools / package import code; confirm before relying on it.
os.environ["NVTE_PROJECT_BUILDING"] = "1"

# Choose the build_ext base class that the CMake build command extends.
# The framework checks run in priority order: pytorch, then paddle, then jax.
# The plain setuptools command is used only when no supported framework was
# detected.
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
    from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    # JAX builds rely on pybind11's setup helpers, so make sure they exist first.
    install_and_import("pybind11[global]")
    from pybind11.setup_helpers import build_ext as BuildExtension
else:
    from setuptools.command.build_ext import build_ext as BuildExtension

CMakeBuildExtension = get_build_ext(BuildExtension)
44

Tim Moon's avatar
Tim Moon committed
45

46
def setup_common_extension() -> CMakeExtension:
    """Describe the CMake extension for the framework-independent common library.

    The CMake project is located at ``transformer_engine/common`` relative to
    the directory containing this setup script. No extra CMake flags are passed.
    """
    # Resolve this script first, then take its directory, so every path is
    # absolute.
    project_dir = Path(__file__).resolve().parent
    common_cmake_dir = project_dir / "transformer_engine" / "common"
    return CMakeExtension(
        name="transformer_engine",
        cmake_path=common_cmake_dir,
        cmake_flags=[],
    )


def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    """Setup Python dependencies

    Returns:
        A 3-tuple ``(setup_reqs, install_reqs, test_reqs)`` holding the build,
        runtime, and testing requirements, in that order. Each list is passed
        through ``remove_dups`` before being returned. CMake, Ninja, and
        pybind11 are only added to the build requirements when the
        corresponding ``found_*`` check reports them missing.
    """

    # Common requirements
    setup_reqs: List[str] = []
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Requirements that may be installed outside of Python
    if not found_cmake():
        setup_reqs.append("cmake>=3.21")
    if not found_ninja():
        setup_reqs.append("ninja")
    if not found_pybind11():
        setup_reqs.append("pybind11")

    # Return an actual tuple so the value matches the annotated return type.
    # Callers that unpack the result keep working unchanged.
    return (
        remove_dups(setup_reqs),
        remove_dups(install_reqs),
        remove_dups(test_reqs),
    )
Tim Moon's avatar
Tim Moon committed
81

82

83
if __name__ == "__main__":
    # Dependencies
    setup_requires, install_requires, test_requires = setup_requirements()

    __version__ = te_version()

    ext_modules = [setup_common_extension()]
    # NVTE_RELEASE_BUILD is parsed as an integer; any non-zero value skips the
    # framework extensions below. A non-integer value raises ValueError here.
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        # Remove residual FW packages since compiling from source
        # results in a single binary with FW extensions included.
        uninstall_te_fw_packages()

        # Package root shared by every framework extension below.
        te_root = current_file_path / "transformer_engine"

        if "pytorch" in frameworks:
            from build_tools.pytorch import setup_pytorch_extension

            ext_modules.append(
                setup_pytorch_extension(
                    "transformer_engine/pytorch/csrc",
                    te_root / "pytorch" / "csrc",
                    te_root,
                )
            )

        if "jax" in frameworks:
            from build_tools.jax import setup_jax_extension

            ext_modules.append(
                setup_jax_extension(
                    "transformer_engine/jax/csrc",
                    te_root / "jax" / "csrc",
                    te_root,
                )
            )

        if "paddle" in frameworks:
            from build_tools.paddle import setup_paddle_extension

            ext_modules.append(
                setup_paddle_extension(
                    "transformer_engine/paddle/csrc",
                    te_root / "paddle" / "csrc",
                    te_root,
                )
            )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        # find_packages() filters dotted package names, e.g.
        # "transformer_engine.pytorch", with fnmatch. A slash-separated pattern
        # such as "transformer_engine/build_tools" can never match, so it has
        # been dropped. Only transformer_engine and its subpackages are
        # packaged, exactly as before.
        packages=setuptools.find_packages(
            include=[
                "transformer_engine",
                "transformer_engine.*",
            ],
        ),
        extras_require={
            "test": test_requires,
            # Framework bindings are pinned to the exact core version.
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
            "paddle": [f"transformer_engine_paddle=={__version__}"],
        },
        description="Transformer acceleration library",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension},
        python_requires=">=3.8, <3.13",
        classifiers=[
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: 3.12",
        ],
        setup_requires=setup_requires,
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=True,
        package_data={"": ["VERSION.txt"]},
    )