setup.py 4.19 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
6
"""Installation script."""

Przemek Tredak's avatar
Przemek Tredak committed
7
import os
Tim Moon's avatar
Tim Moon committed
8
from pathlib import Path
9
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
10

Tim Moon's avatar
Tim Moon committed
11
import setuptools
Przemek Tredak's avatar
Przemek Tredak committed
12

13
14
15
16
17
18
19
20
21
22
from build_tools.build_ext import CMakeExtension, get_build_ext
from build_tools.utils import (
    found_cmake,
    found_ninja,
    found_pybind11,
    remove_dups,
    get_frameworks,
    install_and_import,
)
from build_tools.te_version import te_version
23

Przemek Tredak's avatar
Przemek Tredak committed
24

25
26
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()

# Select the build_ext base class. The first matching framework wins, in the
# order pytorch, paddle, jax; setuptools' own build_ext is the fallback when
# none of those frameworks is listed.
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
    from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    # pybind11 must be importable before its build_ext can be used;
    # install_and_import presumably installs it on demand (see build_tools.utils).
    install_and_import("pybind11")
    from pybind11.setup_helpers import build_ext as BuildExtension
else:
    from setuptools.command.build_ext import build_ext as BuildExtension

# Derive the command class used for every extension (including CMakeExtension)
# from the selected base class.
CMakeBuildExtension = get_build_ext(BuildExtension)
41

Tim Moon's avatar
Tim Moon committed
42

43
def setup_common_extension() -> CMakeExtension:
    """Build the CMake extension description for the shared common library.

    The CMake project is located at ``transformer_engine/common`` beneath the
    directory that contains this setup script, and is configured with an
    empty list of extra CMake flags.
    """
    project_dir = Path(__file__).resolve().parent
    common_cmake_dir = project_dir / Path("transformer_engine/common")
    return CMakeExtension(
        name="transformer_engine",
        cmake_path=common_cmake_dir,
        cmake_flags=[],
    )


def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for build, runtime, and testing.

    Returns:
        A ``(setup_reqs, install_reqs, test_reqs)`` tuple. Each element is the
        corresponding requirement-string list passed through ``remove_dups``.
    """

    # Common requirements
    setup_reqs: List[str] = []
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0; python_version<'3.8'",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Requirements that may be installed outside of Python: only request them
    # from pip when the corresponding found_* probe reports them missing.
    if not found_cmake():
        setup_reqs.append("cmake>=3.21")
    if not found_ninja():
        setup_reqs.append("ninja")
    if not found_pybind11():
        setup_reqs.append("pybind11")

    # Build a real tuple so the value matches the annotated return type;
    # callers that unpack three names keep working unchanged.
    return (
        remove_dups(setup_reqs),
        remove_dups(install_reqs),
        remove_dups(test_reqs),
    )
Tim Moon's avatar
Tim Moon committed
78

79

80
if __name__ == "__main__":

    # Dependencies
    setup_requires, install_requires, test_requires = setup_requirements()

    __version__ = te_version()

    # The common CMake library is always built. Framework extensions are only
    # added when NVTE_RELEASE_BUILD is unset or an integer equal to 0
    # (int() raises ValueError for non-integer values).
    ext_modules = [setup_common_extension()]
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            from build_tools.pytorch import setup_pytorch_extension

            ext_modules.append(
                setup_pytorch_extension(
                    "transformer_engine/pytorch/csrc",
                    current_file_path / "transformer_engine" / "pytorch" / "csrc",
                    current_file_path / "transformer_engine",
                )
            )

        if "jax" in frameworks:
            from build_tools.jax import setup_jax_extension

            ext_modules.append(
                setup_jax_extension(
                    "transformer_engine/jax/csrc",
                    current_file_path / "transformer_engine" / "jax" / "csrc",
                    current_file_path / "transformer_engine",
                )
            )

        if "paddle" in frameworks:
            from build_tools.paddle import setup_paddle_extension

            ext_modules.append(
                setup_paddle_extension(
                    "transformer_engine/paddle/csrc",
                    current_file_path / "transformer_engine" / "paddle" / "csrc",
                    current_file_path / "transformer_engine",
                )
            )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        # find_packages() matches include patterns against dotted package
        # names (e.g. "transformer_engine.pytorch"), so only dotted patterns
        # are meaningful here; a slash-separated path never matches.
        # NOTE(review): if the top-level build_tools package is meant to be
        # shipped, it needs a dotted pattern such as "build_tools" — confirm.
        packages=setuptools.find_packages(
            include=[
                "transformer_engine",
                "transformer_engine.*",
            ],
        ),
        extras_require={
            "test": test_requires,
        },
        description="Transformer acceleration library",
        ext_modules=ext_modules,
        # Every extension, including the CMake-driven common library, is
        # built through the command class chosen at module level.
        cmdclass={"build_ext": CMakeBuildExtension},
        setup_requires=setup_requires,
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=True,
        package_data={"": ["VERSION.txt"]},
    )