# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Installation script."""

import os
import time
from pathlib import Path
from typing import List, Tuple

import setuptools
from wheel.bdist_wheel import bdist_wheel

from build_tools.build_ext import CMakeExtension, get_build_ext
from build_tools.utils import (
    found_cmake,
    found_ninja,
    found_pybind11,
    remove_dups,
    get_frameworks,
    install_and_import,
    uninstall_te_fw_packages,
)
from build_tools.te_version import te_version


# Frameworks selected for this build, as reported by the build tooling.
frameworks = get_frameworks()
# Absolute path of the directory that contains this setup script.
current_file_path = Path(__file__).parent.resolve()


# Default ``build_ext`` base class. It is rebound below when one of the
# supported frameworks is selected.
from setuptools.command.build_ext import build_ext as BuildExtension

# Exported before any framework package is imported below.
# NOTE(review): presumably consumed by Transformer Engine code that checks
# whether it is being imported during a build -- confirm against the reader.
os.environ["NVTE_PROJECT_BUILDING"] = "1"

# Only the first matching framework decides the base class
# (priority: pytorch, then paddle, then jax).
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
    from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    # pybind11 has to be importable before its setup helpers can be used.
    install_and_import("pybind11[global]")
    from pybind11.setup_helpers import build_ext as BuildExtension


# Concrete ``build_ext`` command class derived from the chosen base class.
CMakeBuildExtension = get_build_ext(BuildExtension)


class TimedBdist(bdist_wheel):
    """Helper class to measure build time

    Behaves like the parent ``bdist_wheel`` command and additionally prints
    how long the command took to run.
    """

    def run(self):
        """Run the parent command and print the elapsed time in seconds.

        The timing line is printed only when the parent ``run`` returns
        normally; an exception raised by it propagates unchanged.
        """
        start_time = time.perf_counter()
        super().run()
        total_time = time.perf_counter() - start_time
        print(f"Total time for bdist_wheel: {total_time:.2f} seconds")


def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library

    Returns:
        A ``CMakeExtension`` named ``"transformer_engine"`` whose CMake path
        is the ``transformer_engine/common`` directory located next to this
        script, configured with an empty list of CMake flags.
    """
    # Project directory root
    root_path = Path(__file__).resolve().parent
    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / "transformer_engine" / "common",
        cmake_flags=[],
    )


def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for build, runtime, and testing.

    Returns:
        A 3-tuple ``(setup_reqs, install_reqs, test_reqs)``. Each list is
        passed through ``remove_dups`` before being returned.
    """

    # Common requirements
    setup_reqs: List[str] = []
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Requirements that may be installed outside of Python
    if not found_cmake():
        setup_reqs.append("cmake>=3.21")
    if not found_ninja():
        setup_reqs.append("ninja")
    if not found_pybind11():
        setup_reqs.append("pybind11")

    # Framework-specific requirements. They are skipped entirely when the
    # NVTE_RELEASE_BUILD environment variable is set to a non-zero integer.
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"])
            test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
        if "jax" in frameworks:
            install_reqs.extend(["jax", "flax>=0.7.1"])
            test_reqs.extend(["numpy", "praxis"])
        if "paddle" in frameworks:
            install_reqs.append("paddlepaddle-gpu")
            test_reqs.append("numpy")

    # Return a real tuple so the value matches the declared return type;
    # callers that unpack three values are unaffected.
    return (
        remove_dups(setup_reqs),
        remove_dups(install_reqs),
        remove_dups(test_reqs),
    )


if __name__ == "__main__":
    # Script entry point: collect requirements and extension modules, then
    # hand everything to setuptools.

    # Dependencies
    setup_requires, install_requires, test_requires = setup_requirements()

    __version__ = te_version()

    # The common library is always built. Framework extensions are added
    # only when NVTE_RELEASE_BUILD is unset or "0".
    ext_modules = [setup_common_extension()]
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        # Remove residual FW packages since compiling from source
        # results in a single binary with FW extensions included.
        uninstall_te_fw_packages()

        # Each framework's build helper is imported lazily, only when that
        # framework is part of this build.
        if "pytorch" in frameworks:
            from build_tools.pytorch import setup_pytorch_extension

            ext_modules.append(
                setup_pytorch_extension(
                    "transformer_engine/pytorch/csrc",
                    current_file_path / "transformer_engine" / "pytorch" / "csrc",
                    current_file_path / "transformer_engine",
                )
            )
        if "jax" in frameworks:
            from build_tools.jax import setup_jax_extension

            ext_modules.append(
                setup_jax_extension(
                    "transformer_engine/jax/csrc",
                    current_file_path / "transformer_engine" / "jax" / "csrc",
                    current_file_path / "transformer_engine",
                )
            )
        if "paddle" in frameworks:
            from build_tools.paddle import setup_paddle_extension

            ext_modules.append(
                setup_paddle_extension(
                    "transformer_engine/paddle/csrc",
                    current_file_path / "transformer_engine" / "paddle" / "csrc",
                    current_file_path / "transformer_engine",
                )
            )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        packages=setuptools.find_packages(
            # NOTE(review): the last entry is written as a filesystem path,
            # while the other entries are dotted package patterns -- confirm
            # whether it is expected to match anything.
            include=[
                "transformer_engine",
                "transformer_engine.*",
                "transformer_engine/build_tools",
            ],
        ),
        extras_require={
            "test": test_requires,
            # Framework extras pin the companion package to this exact version.
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
            "paddle": [f"transformer_engine_paddle=={__version__}"],
        },
        description="Transformer acceleration library",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
        python_requires=">=3.8, <3.13",
        classifiers=[
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: 3.12",
        ],
        setup_requires=setup_requires,
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=True,
        package_data={"": ["VERSION.txt"]},
    )