setup.py 6.15 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
#
# See LICENSE for license information.

5
6
"""Installation script."""

Przemek Tredak's avatar
Przemek Tredak committed
7
import os
Phuong Nguyen's avatar
Phuong Nguyen committed
8
import time
Tim Moon's avatar
Tim Moon committed
9
from pathlib import Path
10
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
11

Tim Moon's avatar
Tim Moon committed
12
import setuptools
Phuong Nguyen's avatar
Phuong Nguyen committed
13
from wheel.bdist_wheel import bdist_wheel
Przemek Tredak's avatar
Przemek Tredak committed
14

15
from build_tools.build_ext import CMakeExtension, get_build_ext
16
from build_tools.te_version import te_version
17
from build_tools.utils import (
18
    cuda_archs,
19
    get_frameworks,
20
    remove_dups,
21
)
Przemek Tredak's avatar
Przemek Tredak committed
22

23
24
# Frameworks whose extensions were requested for this build.
frameworks = get_frameworks()
# Directory that contains this setup.py (the directory path is resolved).
current_file_path = Path(__file__).parent.resolve()

# Mark the process as a project build before any framework import below.
# NOTE(review): the readers of this variable are not visible here — confirm.
os.environ["NVTE_PROJECT_BUILDING"] = "1"

# Pick the build_ext base class: PyTorch takes precedence over JAX, and plain
# setuptools is the fallback when neither framework is being built.
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    from pybind11.setup_helpers import build_ext as BuildExtension
else:
    from setuptools.command.build_ext import build_ext as BuildExtension

# build_ext command class derived from the chosen base (see build_tools.build_ext).
CMakeBuildExtension = get_build_ext(BuildExtension)
# CUDA architecture list, forwarded to CMake as CMAKE_CUDA_ARCHITECTURES.
archs = cuda_archs()
39

Tim Moon's avatar
Tim Moon committed
40

Phuong Nguyen's avatar
Phuong Nguyen committed
41
42
43
44
45
46
47
class TimedBdist(bdist_wheel):
    """Wheel-building command that reports how long ``bdist_wheel`` took.

    Identical to the parent command except that, once the parent ``run``
    returns normally, the elapsed wall-clock time is printed.
    """

    def run(self):
        clock = time.perf_counter  # monotonic clock, suitable for intervals
        t_begin = clock()
        super().run()
        # Not reached if the parent command raises; the exception propagates
        # without a timing line.
        print(f"Total time for bdist_wheel: {clock() - t_begin:.2f} seconds")
Phuong Nguyen's avatar
Phuong Nguyen committed
49
50


51
def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library

    CMake configure flags are assembled from the module-level ``archs`` value
    and these environment variables:

    * ``NVTE_UB_WITH_MPI`` -- when non-zero, adds ``-DNVTE_UB_WITH_MPI=ON``;
      ``MPI_HOME`` must then be set.
    * ``NVTE_ENABLE_NVSHMEM`` -- when non-zero, adds ``-DNVTE_ENABLE_NVSHMEM=ON``;
      ``NVSHMEM_HOME`` must then be set.
    * ``NVTE_BUILD_ACTIVATION_WITH_FAST_MATH`` -- when non-zero, adds
      ``-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON``.
    * ``NVTE_CMAKE_EXTRA_ARGS`` -- extra CMake arguments, split on whitespace
      and appended as-is (shell-style quoting is not interpreted).

    Returns:
        A ``CMakeExtension`` named ``transformer_engine`` whose CMake project is
        ``transformer_engine/common`` next to this file.

    Raises:
        RuntimeError: if an MPI or NVSHMEM build is requested without the
            matching ``*_HOME`` variable.
        ValueError: if one of the boolean flags is not an integer string.
    """
    cmake_flags = [f"-DCMAKE_CUDA_ARCHITECTURES={archs}"]

    # Explicit raises instead of ``assert``: assertions are removed under
    # ``python -O``, which would let a misconfigured build continue silently.
    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
        if os.getenv("MPI_HOME") is None:
            raise RuntimeError("MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1")
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
        if os.getenv("NVSHMEM_HOME") is None:
            raise RuntimeError(
                "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
            )
        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

    if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    # Add custom CMake arguments from environment variable
    nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
    if nvte_cmake_extra_args:
        cmake_flags.extend(nvte_cmake_extra_args.split())

    # Project directory root
    root_path = Path(__file__).resolve().parent

    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / Path("transformer_engine/common"),
        cmake_flags=cmake_flags,
    )


84
def setup_requirements() -> Tuple[List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for runtime and testing.

    Returns:
        ``(install_requires, test_requires)``: de-duplicated runtime and test
        requirement strings. Requirements from ``build_tools.pytorch`` and
        ``build_tools.jax`` are added only for frameworks listed in the
        module-level ``frameworks``, and only when ``NVTE_RELEASE_BUILD``
        evaluates to 0 (the default).
    """

    # Common requirements
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Framework-specific requirements (skipped for release builds).
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            from build_tools.pytorch import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())
        if "jax" in frameworks:
            from build_tools.jax import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())

    # Return a real tuple so the result matches the annotated return type.
    return remove_dups(install_reqs), remove_dups(test_reqs)
Tim Moon's avatar
Tim Moon committed
112

113

114
115
if __name__ == "__main__":
    __version__ = te_version()

    with open("README.rst", encoding="utf-8") as f:
        long_description = f.read()

    # Settings for building top level empty package for dependency management.
    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        # Explicit check instead of ``assert`` so it still fires under ``python -O``.
        if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
            raise RuntimeError("NVTE_RELEASE_BUILD env must be set for metapackage build.")
        ext_modules = []
        package_data = {}
        include_package_data = False
        # install_requires must be a flat list of requirement strings (a trailing
        # comma here would wrap it in a tuple).
        install_requires = [f"transformer_engine_cu12=={__version__}"]
        extras_require = {
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
        }
    else:
        install_requires, test_requires = setup_requirements()
        ext_modules = [setup_common_extension()]
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True
        extras_require = {"test": test_requires}

        # Framework extensions are only added for non-release builds.
        if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
            if "pytorch" in frameworks:
                from build_tools.pytorch import setup_pytorch_extension

                ext_modules.append(
                    setup_pytorch_extension(
                        "transformer_engine/pytorch/csrc",
                        current_file_path / "transformer_engine" / "pytorch" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )
            if "jax" in frameworks:
                from build_tools.jax import setup_jax_extension

                ext_modules.append(
                    setup_jax_extension(
                        "transformer_engine/jax/csrc",
                        current_file_path / "transformer_engine" / "jax" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        # find_packages matches ``include`` patterns against dotted package
        # names, so a path-style pattern such as "transformer_engine/build_tools"
        # can never match; "transformer_engine.*" already covers subpackages.
        packages=setuptools.find_packages(
            include=[
                "transformer_engine",
                "transformer_engine.*",
            ],
        ),
        extras_require=extras_require,
        description="Transformer acceleration library",
        long_description=long_description,
        long_description_content_type="text/x-rst",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
        python_requires=">=3.8",
        classifiers=["Programming Language :: Python :: 3"],
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=include_package_data,
        package_data=package_data,
    )