# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Installation script."""

# Example ROCm (DTK) build commands.
# Install directly with pip:
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
# Or build a wheel (hipify_torch must be importable):
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel

Przemek Tredak's avatar
Przemek Tredak committed
9
import os
Phuong Nguyen's avatar
Phuong Nguyen committed
10
import time
Tim Moon's avatar
Tim Moon committed
11
from pathlib import Path
12
from typing import List, Tuple
Przemek Tredak's avatar
Przemek Tredak committed
13

Tim Moon's avatar
Tim Moon committed
14
import setuptools
Phuong Nguyen's avatar
Phuong Nguyen committed
15
from wheel.bdist_wheel import bdist_wheel
Przemek Tredak's avatar
Przemek Tredak committed
16

17
from build_tools.build_ext import CMakeExtension, get_build_ext
18
from build_tools.te_version import te_version
19
from build_tools.utils import (
yuguo's avatar
yuguo committed
20
    rocm_build,
21
    cuda_archs,
22
    get_frameworks,
23
    remove_dups,
24
)
Przemek Tredak's avatar
Przemek Tredak committed
25

26
27
# Frameworks requested for this build, as reported by build_tools.utils.get_frameworks()
# (presumably driven by NVTE_FRAMEWORK, see the example commands above -- confirm).
frameworks = get_frameworks()
# Absolute directory that contains this setup.py.
current_file_path = Path(__file__).parent.resolve()


# Fallback build_ext command, used when no framework-specific one is selected below.
from setuptools.command.build_ext import build_ext as BuildExtension

# Exported so that code run during the build can tell it is executing inside setup.py.
os.environ["NVTE_PROJECT_BUILDING"] = "1"

# A framework-specific build_ext replaces the setuptools fallback. PyTorch wins
# when both PyTorch and JAX are requested.
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    from pybind11.setup_helpers import build_ext as BuildExtension

# build_ext command class that also drives the CMake-based extensions.
CMakeBuildExtension = get_build_ext(BuildExtension)

# CUDA architectures for CMAKE_CUDA_ARCHITECTURES. ROCm builds do not use them.
archs = None if rocm_build() else cuda_archs()
45

Tim Moon's avatar
Tim Moon committed
46

Phuong Nguyen's avatar
Phuong Nguyen committed
47
48
49
50
51
52
53
class TimedBdist(bdist_wheel):
    """``bdist_wheel`` command that reports how long building the wheel took."""

    def run(self):
        """Run the stock ``bdist_wheel`` command, then print its wall-clock duration."""
        started_at = time.perf_counter()
        super().run()
        total_time = time.perf_counter() - started_at
        print(f"Total time for bdist_wheel: {total_time:.2f} seconds")
Phuong Nguyen's avatar
Phuong Nguyen committed
55
56


57
def _env_switch_enabled(name: str) -> bool:
    """Return True if env var ``name`` is set to anything other than an explicit "off" value.

    An unset variable is False. ``"0"``, ``"false"``, ``"off"`` and ``"no"`` are also
    False; the comparison is case-insensitive and ignores surrounding whitespace.
    Every other value, including the empty string, is True. This keeps the old
    "set means enabled" behaviour for values such as ``"1"`` or ``"ON"``.
    """
    value = os.getenv(name)
    if value is None:
        return False
    return value.strip().lower() not in ("0", "false", "off", "no")


def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library.

    CMake flags are assembled from environment variables:

    * ROCm builds (``rocm_build()``): compiler-warning suppression switches,
      each on by default, plus the optional ``NVTE_USE_HIPBLASLT`` and
      ``NVTE_USE_ROCBLAS`` backends.
    * CUDA builds: ``-DCMAKE_CUDA_ARCHITECTURES`` from the module-level ``archs``.
    * Both: ``NVTE_UB_WITH_MPI`` (requires ``MPI_HOME``), ``NVTE_ENABLE_NVSHMEM``
      (requires ``NVSHMEM_HOME``), ``NVTE_BUILD_ACTIVATION_WITH_FAST_MATH``, and
      ``NVTE_CMAKE_EXTRA_ARGS``, which is split on whitespace and appended verbatim.

    Returns:
        A ``CMakeExtension`` named ``transformer_engine``, built from
        ``transformer_engine/common`` next to this file.

    Raises:
        AssertionError: if MPI or NVSHMEM support is requested without the
            matching ``*_HOME`` variable.
        ValueError: if one of the integer-valued switches is not an integer.
    """
    if rocm_build():
        cmake_flags = []
        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"))):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_UNUSED_WARNING=ON")
        if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "1"))):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=ON")
        # This switch used to be read without the "_WARNING" suffix that the other
        # switches and the CMake option use. Accept both spellings; the suffixed
        # one takes precedence, so existing build scripts keep working.
        sign_compare = os.getenv(
            "NVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING",
            os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "1"),
        )
        if bool(int(sign_compare)):
            cmake_flags.append("-DNVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=ON")
    else:
        cmake_flags = [f"-DCMAKE_CUDA_ARCHITECTURES={archs}"]

    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
        assert (
            os.getenv("MPI_HOME") is not None
        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

    if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    # Add custom CMake arguments from environment variable
    nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
    if nvte_cmake_extra_args:
        cmake_flags.extend(nvte_cmake_extra_args.split())

    # Project directory root
    root_path = Path(__file__).resolve().parent

    if rocm_build():
        # Previously any value, including "0", enabled these backends (the code
        # only checked "is not None"). Explicit off values now disable them.
        if _env_switch_enabled("NVTE_USE_HIPBLASLT"):
            cmake_flags.append("-DUSE_HIPBLASLT=ON")
        if _env_switch_enabled("NVTE_USE_ROCBLAS"):
            cmake_flags.append("-DUSE_ROCBLAS=ON")

    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / Path("transformer_engine/common"),
        cmake_flags=cmake_flags,
    )


105
def setup_requirements() -> Tuple[List[str], List[str]]:
    """Setup Python dependencies.

    Framework-specific requirements come from ``build_tools.pytorch`` and/or
    ``build_tools.jax``, depending on ``frameworks``. They are added only when
    ``NVTE_RELEASE_BUILD`` is unset or ``0``.

    Returns:
        A ``(install_requires, test_requires)`` tuple. Each list is passed
        through ``remove_dups``.
    """

    # Common requirements
    install_reqs: List[str] = [
        "pydantic",
        "importlib-metadata>=1.0",
        "packaging",
    ]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Framework-specific requirements
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            from build_tools.pytorch import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())
        if "jax" in frameworks:
            from build_tools.jax import install_requirements, test_requirements

            install_reqs.extend(install_requirements())
            test_reqs.extend(test_requirements())

    # Return a real tuple, matching the annotation (this used to be a 2-element list).
    return remove_dups(install_reqs), remove_dups(test_reqs)
Tim Moon's avatar
Tim Moon committed
133

134

135
136
if __name__ == "__main__":
    __version__ = te_version()

    # Read README.rst next to this file, not from the current working directory.
    with open(current_file_path / "README.rst", encoding="utf-8") as f:
        long_description = f.read()

    # Settings for building top level empty package for dependency management.
    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        assert bool(
            int(os.getenv("NVTE_RELEASE_BUILD", "0"))
        ), "NVTE_RELEASE_BUILD env must be set for metapackage build."
        ext_modules = []
        package_data = {}
        include_package_data = False
        # A flat list of requirement strings. A stray trailing comma used to wrap it in a 1-tuple.
        install_requires = [f"transformer_engine_cu12=={__version__}"]
        extras_require = {
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
        }
    else:
        install_requires, test_requires = setup_requirements()
        ext_modules = [setup_common_extension()]
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True
        extras_require = {"test": test_requires}

        # Framework extensions are built only for non-release builds.
        if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
            if "pytorch" in frameworks:
                from build_tools.pytorch import setup_pytorch_extension

                ext_modules.append(
                    setup_pytorch_extension(
                        "transformer_engine/pytorch/csrc",
                        current_file_path / "transformer_engine" / "pytorch" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )
            if "jax" in frameworks:
                from build_tools.jax import setup_jax_extension

                ext_modules.append(
                    setup_jax_extension(
                        "transformer_engine/jax/csrc",
                        current_file_path / "transformer_engine" / "jax" / "csrc",
                        current_file_path / "transformer_engine",
                    )
                )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        # find_packages matches dotted package names, so "transformer_engine.*"
        # already covers every subpackage.
        packages=setuptools.find_packages(
            include=[
                "transformer_engine",
                "transformer_engine.*",
            ],
        ),
        extras_require=extras_require,
        description="Transformer acceleration library",
        long_description=long_description,
        long_description_content_type="text/x-rst",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
        python_requires=">=3.8",
        classifiers=["Programming Language :: Python :: 3"],
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=include_package_data,
        package_data=package_data,
    )