# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
6
"""Installation script."""

7
from importlib import metadata
Przemek Tredak's avatar
Przemek Tredak committed
8
import os
Phuong Nguyen's avatar
Phuong Nguyen committed
9
import time
Tim Moon's avatar
Tim Moon committed
10
from pathlib import Path
11
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
12

Tim Moon's avatar
Tim Moon committed
13
import setuptools
Phuong Nguyen's avatar
Phuong Nguyen committed
14
from wheel.bdist_wheel import bdist_wheel
Przemek Tredak's avatar
Przemek Tredak committed
15

16
from build_tools.build_ext import CMakeExtension, get_build_ext
17
from build_tools.te_version import te_version
18
from build_tools.utils import (
19
    cuda_archs,
20
    cuda_version,
21
    get_frameworks,
22
    remove_dups,
23
    min_python_version_str,
24
)
Przemek Tredak's avatar
Przemek Tredak committed
25

26
27
# Framework names reported by get_frameworks(); checked below for "pytorch"
# and "jax". Also the resolved directory that contains this script.
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()


# Default build_ext base class; rebound below when a framework is selected.
from setuptools.command.build_ext import build_ext as BuildExtension

# Set before the framework-specific imports below.
# NOTE(review): presumably read elsewhere in the project to detect that a
# build is in progress -- confirm against the consumers of this variable.
os.environ["NVTE_PROJECT_BUILDING"] = "1"

# "pytorch" takes precedence over "jax" when both are present.
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    from pybind11.setup_helpers import build_ext as BuildExtension


# Command class registered for "build_ext" in setup(), produced by
# get_build_ext() from the base selected above.
CMakeBuildExtension = get_build_ext(BuildExtension)
# Value substituted into -DCMAKE_CUDA_ARCHITECTURES by setup_common_extension().
archs = cuda_archs()
42

Tim Moon's avatar
Tim Moon committed
43

Phuong Nguyen's avatar
Phuong Nguyen committed
44
45
46
47
48
49
50
class TimedBdist(bdist_wheel):
    """Helper class to measure build time"""

    def run(self):
        """Run the parent ``bdist_wheel`` command and print how long it took.

        The duration is measured with ``time.perf_counter`` and printed only
        after the parent ``run()`` returns.
        """
        start_time = time.perf_counter()
        super().run()
        total_time = time.perf_counter() - start_time
        print(f"Total time for bdist_wheel: {total_time:.2f} seconds")
Phuong Nguyen's avatar
Phuong Nguyen committed
52
53


54
def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library

    The CMake flag list starts with ``-DCMAKE_CUDA_ARCHITECTURES`` (taken from
    the module-level ``archs``) and is extended from environment variables:

    * ``NVTE_UB_WITH_MPI=1``: adds ``-DNVTE_UB_WITH_MPI=ON``; requires
      ``MPI_HOME`` to be set.
    * ``NVTE_ENABLE_NVSHMEM=1``: adds ``-DNVTE_ENABLE_NVSHMEM=ON``; requires
      ``NVSHMEM_HOME`` to be set.
    * ``NVTE_BUILD_ACTIVATION_WITH_FAST_MATH=1``: adds
      ``-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON``.
    * ``NVTE_WITH_CUBLASMP=1``: adds ``-DNVTE_WITH_CUBLASMP=ON`` plus
      ``-DCUBLASMP_DIR`` and ``-DNVSHMEM_DIR``. Each directory comes from
      ``CUBLASMP_HOME`` / ``NVSHMEM_HOME`` when set and non-empty, otherwise
      it is located through the installed ``nvidia-cublasmp-cu<N>`` /
      ``nvidia-nvshmem-cu<N>`` distribution metadata.
    * ``NVTE_CMAKE_EXTRA_ARGS``: whitespace-separated flags appended as-is.

    Returns:
        The ``CMakeExtension`` named ``transformer_engine`` whose CMake
        project lives in ``transformer_engine/common`` next to this script.

    Raises:
        RuntimeError: If ``MPI_HOME`` or ``NVSHMEM_HOME`` is required by an
            enabled option but not set.
    """

    def env_flag(name: str) -> bool:
        """Read an integer-valued environment variable as a bool (unset -> False)."""
        return bool(int(os.getenv(name, "0")))

    cmake_flags = [f"-DCMAKE_CUDA_ARCHITECTURES={archs}"]

    if env_flag("NVTE_UB_WITH_MPI"):
        # Explicit raise instead of assert so the check survives `python -O`.
        if os.getenv("MPI_HOME") is None:
            raise RuntimeError("MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1")
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

    if env_flag("NVTE_ENABLE_NVSHMEM"):
        if os.getenv("NVSHMEM_HOME") is None:
            raise RuntimeError(
                "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
            )
        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

    if env_flag("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH"):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    if env_flag("NVTE_WITH_CUBLASMP"):
        cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
        cublasmp_dir = os.getenv("CUBLASMP_HOME")
        nvshmem_dir = os.getenv("NVSHMEM_HOME")
        if not cublasmp_dir or not nvshmem_dir:
            # Query the CUDA version only when a metadata lookup is actually
            # needed, and only once for both packages.
            cu_tag = cuda_version()[0]
            if not cublasmp_dir:
                cublasmp_dir = metadata.distribution(
                    f"nvidia-cublasmp-cu{cu_tag}"
                ).locate_file(f"nvidia/cublasmp/cu{cu_tag}")
            if not nvshmem_dir:
                nvshmem_dir = metadata.distribution(
                    f"nvidia-nvshmem-cu{cu_tag}"
                ).locate_file("nvidia/nvshmem")
        cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
        cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
        # Echo the two directory flags that were just appended.
        print("CMAKE_FLAGS:", cmake_flags[-2:])

    # Add custom CMake arguments from environment variable
    nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
    if nvte_cmake_extra_args:
        cmake_flags.extend(nvte_cmake_extra_args.split())

    # Project directory root
    root_path = Path(__file__).resolve().parent

    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / Path("transformer_engine/common"),
        cmake_flags=cmake_flags,
    )


99
def setup_requirements() -> Tuple[List[str], List[str]]:
    """Setup Python dependencies

    Starts from the common requirement lists and, unless
    ``NVTE_RELEASE_BUILD`` is set to a non-zero integer, extends them with the
    requirements declared by ``build_tools.pytorch`` and/or ``build_tools.jax``
    for every framework present in the module-level ``frameworks``.

    Returns:
        A ``(install_requirements, test_requirements)`` tuple; each list has
        been passed through ``remove_dups``.
    """

    # Common requirements
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Framework-specific requirements
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            from build_tools.pytorch import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())
        if "jax" in frameworks:
            # Rebinds the same two names to the JAX variants.
            from build_tools.jax import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())

    # Return a real tuple so the value matches the declared return type.
    return remove_dups(install_reqs), remove_dups(test_reqs)
Tim Moon's avatar
Tim Moon committed
127

128

129
130
if __name__ == "__main__":
    __version__ = te_version()

    # Resolve README.rst next to this script instead of relying on the
    # current working directory.
    with open(current_file_path / "README.rst", encoding="utf-8") as f:
        long_description = f.read()

    # Parsed once; used by both the metapackage check and the extension setup.
    release_build = bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))

    # Settings for building top level empty package for dependency management.
    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        # Explicit raise instead of assert so the check survives `python -O`.
        if not release_build:
            raise RuntimeError("NVTE_RELEASE_BUILD env must be set for metapackage build.")
        ext_modules = []
        package_data = {}
        include_package_data = False
        install_requires = []
        # The metapackage ships no code; each extra pins a separately built
        # distribution to this exact version.
        extras_require = {
            "core": [f"transformer_engine_cu12=={__version__}"],
            "core_cu12": [f"transformer_engine_cu12=={__version__}"],
            "core_cu13": [f"transformer_engine_cu13=={__version__}"],
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
        }
    else:
        install_requires, test_requires = setup_requirements()
        ext_modules = [setup_common_extension()]
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True
        extras_require = {"test": test_requires}

        # Framework extensions are only added for non-release builds.
        if not release_build:
            if "pytorch" in frameworks:
                from build_tools.pytorch import setup_pytorch_extension

                ext_modules.append(
                    setup_pytorch_extension(
                        "transformer_engine/pytorch/csrc",
                        current_file_path / "transformer_engine" / "pytorch" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )
            if "jax" in frameworks:
                from build_tools.jax import setup_jax_extension

                ext_modules.append(
                    setup_jax_extension(
                        "transformer_engine/jax/csrc",
                        current_file_path / "transformer_engine" / "jax" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        packages=setuptools.find_packages(
            include=[
                "transformer_engine",
                "transformer_engine.*",
                "transformer_engine/build_tools",
            ],
        ),
        extras_require=extras_require,
        description="Transformer acceleration library",
        long_description=long_description,
        long_description_content_type="text/x-rst",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
        python_requires=f">={min_python_version_str()}",
        classifiers=["Programming Language :: Python :: 3"],
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=include_package_data,
        package_data=package_data,
    )