# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
6
"""Installation script."""

Przemek Tredak's avatar
Przemek Tredak committed
7
import os
Phuong Nguyen's avatar
Phuong Nguyen committed
8
import time
Tim Moon's avatar
Tim Moon committed
9
from pathlib import Path
10
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
11

Tim Moon's avatar
Tim Moon committed
12
import setuptools
Phuong Nguyen's avatar
Phuong Nguyen committed
13
from wheel.bdist_wheel import bdist_wheel
Przemek Tredak's avatar
Przemek Tredak committed
14

15
16
17
18
19
20
21
22
from build_tools.build_ext import CMakeExtension, get_build_ext
from build_tools.utils import (
    found_cmake,
    found_ninja,
    found_pybind11,
    remove_dups,
    get_frameworks,
    install_and_import,
23
    uninstall_te_fw_packages,
24
25
)
from build_tools.te_version import te_version
26

Przemek Tredak's avatar
Przemek Tredak committed
27

28
29
# Frameworks this build should handle, plus the directory containing this file.
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()

# Default build_ext command; replaced below when a supported framework is detected.
from setuptools.command.build_ext import build_ext as BuildExtension

# Exported for the remainder of the build; whatever reads it lives outside this file.
os.environ["NVTE_PROJECT_BUILDING"] = "1"

# A framework-provided build_ext wins, with precedence pytorch > paddle > jax.
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
    from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    # NOTE(review): install_and_import (build_tools.utils) presumably ensures
    # pybind11 is importable before the import below — confirm.
    install_and_import("pybind11[global]")
    from pybind11.setup_helpers import build_ext as BuildExtension

# Reference point for the "Total build time" report printed at the end of __main__.
start_time = time.perf_counter()

# NOTE(review): get_build_ext (build_tools.build_ext) presumably wraps the chosen
# build_ext so it can also build CMakeExtension modules — confirm.
CMakeBuildExtension = get_build_ext(BuildExtension)
49

Tim Moon's avatar
Tim Moon committed
50

Phuong Nguyen's avatar
Phuong Nguyen committed
51
52
53
54
55
56
57
58
59
60
class TimedBdist(bdist_wheel):
    """``bdist_wheel`` command that reports how long the wheel build took."""

    def run(self):
        # Delegate to the stock command, then print the wall-clock duration.
        began = time.perf_counter()
        super().run()
        elapsed = time.perf_counter() - began
        print(f"Time for bdist_wheel: {elapsed:.2f} seconds")


61
def setup_common_extension() -> CMakeExtension:
    """Describe the CMake extension for the common library.

    Returns:
        A ``CMakeExtension`` named ``transformer_engine`` whose CMake project
        is ``transformer_engine/common`` under the directory containing this
        file (symlinks resolved first), with no extra CMake flags.
    """
    common_dir = Path(__file__).resolve().parent / "transformer_engine" / "common"
    return CMakeExtension(name="transformer_engine", cmake_path=common_dir, cmake_flags=[])


def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    """Setup Python dependencies

    Returns:
        A ``(setup_requires, install_requires, test_requires)`` tuple holding
        the requirement lists for build, runtime, and testing. Each list is
        passed through ``remove_dups`` before being returned.
    """

    # Common requirements
    setup_reqs: List[str] = []
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Requirements that may be installed outside of Python; only ask pip for
    # them when the corresponding found_* probe reports they are missing.
    if not found_cmake():
        setup_reqs.append("cmake>=3.21")
    if not found_ninja():
        setup_reqs.append("ninja")
    if not found_pybind11():
        setup_reqs.append("pybind11")

    # Return a real tuple so the value matches the annotated return type;
    # callers unpack it into three names exactly as before.
    return (
        remove_dups(setup_reqs),
        remove_dups(install_reqs),
        remove_dups(test_reqs),
    )
Tim Moon's avatar
Tim Moon committed
96

97

98
if __name__ == "__main__":
    # Build-time, runtime and test-time Python requirements.
    setup_requires, install_requires, test_requires = setup_requirements()

    __version__ = te_version()

    # The common library extension is always included; framework extensions
    # are added only when NVTE_RELEASE_BUILD is unset or parses to 0.
    ext_modules = [setup_common_extension()]
    release_build = bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))
    if not release_build:
        # Remove residual FW packages since compiling from source
        # results in a single binary with FW extensions included.
        uninstall_te_fw_packages()

        te_dir = current_file_path / "transformer_engine"

        if "pytorch" in frameworks:
            from build_tools.pytorch import setup_pytorch_extension

            ext_modules.append(
                setup_pytorch_extension(
                    "transformer_engine/pytorch/csrc",
                    te_dir / "pytorch" / "csrc",
                    te_dir,
                )
            )

        if "jax" in frameworks:
            from build_tools.jax import setup_jax_extension

            ext_modules.append(
                setup_jax_extension(
                    "transformer_engine/jax/csrc",
                    te_dir / "jax" / "csrc",
                    te_dir,
                )
            )

        if "paddle" in frameworks:
            from build_tools.paddle import setup_paddle_extension

            ext_modules.append(
                setup_paddle_extension(
                    "transformer_engine/paddle/csrc",
                    te_dir / "paddle" / "csrc",
                    te_dir,
                )
            )

    # Arguments for setuptools.setup, evaluated in the same order the
    # original inline call evaluated them.
    packages = setuptools.find_packages(
        include=[
            "transformer_engine",
            "transformer_engine.*",
            # NOTE(review): find_packages matches dotted package names, so this
            # slash-separated entry appears never to match — confirm intent.
            "transformer_engine/build_tools",
        ],
    )
    extras_require = {
        "test": test_requires,
        "pytorch": [f"transformer_engine_torch=={__version__}"],
        "jax": [f"transformer_engine_jax=={__version__}"],
        "paddle": [f"transformer_engine_paddle=={__version__}"],
    }

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        packages=packages,
        extras_require=extras_require,
        description="Transformer acceleration library",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
        python_requires=">=3.8, <3.13",
        classifiers=[
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: 3.12",
        ],
        setup_requires=setup_requires,
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=True,
        package_data={"": ["VERSION.txt"]},
    )

    # Elapsed time since start_time was captured at module load.
    end_time = time.perf_counter()
    total_time = end_time - start_time
    print(f"Total build time: {total_time:.2f} seconds")