setup.py 10.3 KB
Newer Older
1
# Copyright (c) 2023, Tri Dao.
2
import logging
Tri Dao's avatar
Tri Dao committed
3
4
import sys
import os
5
6
import re
import ast
7
from collections import namedtuple
Tri Dao's avatar
Tri Dao committed
8
from pathlib import Path
9
10
11
from typing import Dict
from shutil import which
from packaging.version import Version, parse
Tri Dao's avatar
Tri Dao committed
12
13
14
15

import subprocess

import torch
16
17
18
19
from torch.utils.cpp_extension import (
    BuildExtension,
    CUDA_HOME,
)
Tri Dao's avatar
Tri Dao committed
20

21
22
23
24
25
26
27
28
29
30
31
32
33
34
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext

logger = logging.getLogger(__name__)

# Environment variables that configure the build. Each value is read once,
# when this script is loaded.
Envs = namedtuple(
    "Envs",
    "VERBOSE MAX_JOBS NVCC_THREADS VLLM_TARGET_DEVICE CMAKE_BUILD_TYPE",
)
envs = Envs(
    # VERBOSE is parsed as an integer; any non-zero value enables it.
    VERBOSE=int(os.environ.get("VERBOSE", "0")) != 0,
    # The remaining values are kept as raw strings (or None when unset).
    MAX_JOBS=os.environ.get("MAX_JOBS"),
    NVCC_THREADS=os.environ.get("NVCC_THREADS"),
    VLLM_TARGET_DEVICE=os.environ.get("VLLM_TARGET_DEVICE", "cuda"),
    CMAKE_BUILD_TYPE=os.environ.get("CMAKE_BUILD_TYPE"),
)
Tri Dao's avatar
Tri Dao committed
35
36
37
38

# Load the README text (UTF-8) from the current working directory.
long_description = Path("README.md").read_text(encoding="utf-8")

Tri Dao's avatar
Tri Dao committed
39
40
41
# Absolute path of the directory that contains this file.
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

Woosuk Kwon's avatar
Woosuk Kwon committed
42
# Name of the Python package directory. It is used to locate the package's
# __init__.py when reading the version and to exclude the egg-info directory
# from package discovery.
PACKAGE_NAME = "vllm_flash_attn"
Tri Dao's avatar
Tri Dao committed
43

44
45
# NOTE(review): `cmdclass` is not referenced again in the visible part of this
# file (setup() is given its own dict) - confirm before removing.
cmdclass = {}
# Extension modules to build; filled in before being passed to setup().
ext_modules = []
46

47
48
# TODO(luka): This should be replaced with a fetch_content call in CMakeLists.txt
# Initialise/update the csrc/cutlass git submodule every time this script runs.
# The result of subprocess.run() is not inspected, so a non-zero exit status
# from git does not stop the build at this point.
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
49
50


51
52
def is_sccache_available() -> bool:
    """Report whether an ``sccache`` executable can be located via ``which``."""
    sccache_path = which("sccache")
    return sccache_path is not None
53

Tri Dao's avatar
Tri Dao committed
54

55
56
def is_ccache_available() -> bool:
    """Report whether a ``ccache`` executable can be located via ``which``."""
    ccache_path = which("ccache")
    return ccache_path is not None
Tri Dao's avatar
Tri Dao committed
57
58


59
60
def is_ninja_available() -> bool:
    """Report whether a ``ninja`` executable can be located via ``which``."""
    ninja_path = which("ninja")
    return ninja_path is not None
Tri Dao's avatar
Tri Dao committed
61
62


63
64
65
66
def remove_prefix(text, prefix):
    """Return ``text`` with a leading ``prefix`` stripped, if it is present.

    When ``text`` does not start with ``prefix`` it is returned unchanged.
    """
    return text[len(prefix):] if text.startswith(prefix) else text
Tri Dao's avatar
Tri Dao committed
67

68
69

# Build target device, taken from the VLLM_TARGET_DEVICE environment variable
# (defaults to "cuda" when the variable is unset).
VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE
Tri Dao's avatar
Tri Dao committed
70
71


72
73
74
75
76
77
78
79
def _is_cuda() -> bool:
    """Return True when the target device is "cuda" and torch reports a CUDA version."""
    if VLLM_TARGET_DEVICE != "cuda":
        return False
    return torch.version.cuda is not None


def _is_hip() -> bool:
    """Return True when the target device is "cuda" or "rocm" and torch reports a HIP version."""
    targets_gpu = VLLM_TARGET_DEVICE in ("cuda", "rocm")
    return targets_gpu and torch.version.hip is not None
Tri Dao's avatar
Tri Dao committed
80

Tri Dao's avatar
Tri Dao committed
81

82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
class CMakeExtension(Extension):
    """An ``Extension`` with no sources that records where its CMakeLists.txt lives."""

    def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:
        """Create the extension.

        Args:
            name: Dotted name of the extension module.
            cmake_lists_dir: Directory holding the CMakeLists.txt; it is stored
                on the instance as an absolute path.
            **kwa: Extra keyword arguments forwarded to ``Extension``.
        """
        # Resolve the directory up front so later changes of the working
        # directory cannot affect it.
        resolved_dir = os.path.abspath(cmake_lists_dir)
        super().__init__(name, sources=[], py_limited_api=True, **kwa)
        self.cmake_lists_dir = resolved_dir


class cmake_build_ext(build_ext):
    """``build_ext`` command that configures, builds and installs extensions via CMake.

    Every extension is configured once per ``cmake_lists_dir``, all targets
    are built with a single ``cmake --build`` invocation, and each extension
    is then installed as its own CMake component.
    """

    # A dict of extension directories that have been configured.
    did_config: Dict[str, bool] = {}

    def compute_num_jobs(self):
        """Determine the number of compilation jobs and, optionally, nvcc threads.

        Returns:
            A ``(num_jobs, nvcc_threads)`` tuple. ``nvcc_threads`` is ``None``
            unless the target is CUDA and nvcc is at least version 11.2.
        """
        # `num_jobs` is either the value of the MAX_JOBS environment variable
        # (if defined) or the number of CPUs available.
        num_jobs = envs.MAX_JOBS
        if num_jobs is not None:
            num_jobs = int(num_jobs)
            logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
        else:
            try:
                # os.sched_getaffinity() isn't universally available, so fall
                #  back to os.cpu_count() if we get an error here.
                num_jobs = len(os.sched_getaffinity(0))
            except AttributeError:
                # os.cpu_count() returns None when the count cannot be
                # determined; fall back to a single job in that case.
                num_jobs = os.cpu_count() or 1

        nvcc_threads = None
        if _is_cuda() and get_nvcc_cuda_version() >= Version("11.2"):
            # `nvcc_threads` is either the value of the NVCC_THREADS
            # environment variable (if defined) or 1.
            # when it is set, we reduce `num_jobs` to avoid
            # overloading the system.
            nvcc_threads = envs.NVCC_THREADS
            if nvcc_threads is not None:
                nvcc_threads = int(nvcc_threads)
                logger.info(
                    "Using NVCC_THREADS=%d as the number of nvcc threads.",
                    nvcc_threads)
            else:
                nvcc_threads = 1
            num_jobs = max(1, num_jobs // nvcc_threads)

        return num_jobs, nvcc_threads

    def configure(self, ext: CMakeExtension) -> None:
        """Run the CMake configure step for a single extension.

        The step is skipped when the extension's ``cmake_lists_dir`` has
        already been configured by this command class.

        Args:
            ext: The extension whose CMakeLists.txt directory is configured.

        Raises:
            subprocess.CalledProcessError: If the ``cmake`` invocation fails.
        """
        # If we've already configured using the CMakeLists.txt for
        # this extension, exit early.
        if ext.cmake_lists_dir in cmake_build_ext.did_config:
            return

        cmake_build_ext.did_config[ext.cmake_lists_dir] = True

        # Select the build type.
        # Note: optimization level + debug info are set by the build type
        default_cfg = "Debug" if self.debug else "RelWithDebInfo"
        cfg = envs.CMAKE_BUILD_TYPE or default_cfg

        cmake_args = [
            f'-DCMAKE_BUILD_TYPE={cfg}',
            f'-DVLLM_TARGET_DEVICE={VLLM_TARGET_DEVICE}',
        ]

        if envs.VERBOSE:
            cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON']

        # Prefer sccache over ccache when both are installed.
        if is_sccache_available():
            cmake_args += [
                '-DCMAKE_CXX_COMPILER_LAUNCHER=sccache',
                '-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache',
                '-DCMAKE_C_COMPILER_LAUNCHER=sccache',
            ]
        elif is_ccache_available():
            cmake_args += [
                '-DCMAKE_CXX_COMPILER_LAUNCHER=ccache',
                '-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache',
            ]

        # Pass the python executable to cmake so it can find an exact
        # match.
        cmake_args += [f'-DPython_EXECUTABLE={sys.executable}']

        # Pass the python path to cmake so it can reuse the build dependencies
        # on subsequent calls to python.
        cmake_args += ['-DVLLM_PYTHON_PATH={}'.format(":".join(sys.path))]

        #
        # Setup parallelism and build tool
        #
        num_jobs, nvcc_threads = self.compute_num_jobs()

        if nvcc_threads:
            cmake_args += [f'-DNVCC_THREADS={nvcc_threads}']

        if is_ninja_available():
            build_tool = ['-G', 'Ninja']
            cmake_args += [
                '-DCMAKE_JOB_POOL_COMPILE:STRING=compile',
                f'-DCMAKE_JOB_POOLS:STRING=compile={num_jobs}',
            ]
        else:
            # Default build tool to whatever cmake picks.
            build_tool = []
        subprocess.check_call(
            ['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
            cwd=self.build_temp)

    def build_extensions(self) -> None:
        """Configure, build and install every extension of this command.

        Raises:
            RuntimeError: If the ``cmake`` executable cannot be launched.
            subprocess.CalledProcessError: If a ``cmake`` invocation fails.
        """
        # Ensure that CMake is present and working
        try:
            subprocess.check_output(['cmake', '--version'])
        except OSError as e:
            raise RuntimeError('Cannot find CMake executable') from e

        # Create build directory if it does not exist.
        os.makedirs(self.build_temp, exist_ok=True)

        def target_name(ext_name: str) -> str:
            # CMake target/component names omit the package prefix.
            return remove_prefix(ext_name, "vllm_flash_attn.")

        # Configure all the extensions and collect their CMake targets.
        targets = []
        for ext in self.extensions:
            self.configure(ext)
            targets.append(target_name(ext.name))

        num_jobs, _ = self.compute_num_jobs()

        build_args = [
            "--build",
            ".",
            f"-j={num_jobs}",
            *[f"--target={name}" for name in targets],
        ]

        subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)

        # `self.build_temp` is used as a plain path string; convert it to an
        # absolute Path so it can be compared with `outdir` below (a Path never
        # compares equal to a str).
        build_dir = Path(self.build_temp).absolute()

        # Install the libraries
        for ext in self.extensions:
            # Install the extension into the proper location
            outdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute()

            # Skip if the install directory is the same as the build directory
            if outdir == build_dir:
                continue

            # CMake appends the extension prefix to the install path,
            # and outdir already contains that prefix, so we need to remove it.
            prefix = outdir
            for _ in range(ext.name.count('.')):
                prefix = prefix.parent

            # prefix here should actually be the same for all components
            install_args = [
                "cmake", "--install", ".", "--prefix", prefix, "--component",
                target_name(ext.name)
            ]
            subprocess.check_call(install_args, cwd=self.build_temp)
Tri Dao's avatar
Tri Dao committed
245

Tri Dao's avatar
Tri Dao committed
246

247
def get_package_version():
    """Return the package version declared in the package's ``__init__.py``.

    The ``__version__`` assignment is located with a regular expression and
    its right-hand side is evaluated with ``ast.literal_eval``. When the
    ``FLASH_ATTN_LOCAL_VERSION`` environment variable is set to a non-empty
    value, it is appended as a local version segment (``public+local``).

    Returns:
        The version string.

    Raises:
        RuntimeError: If no ``__version__`` assignment is found.
    """
    init_path = Path(this_dir) / PACKAGE_NAME / "__init__.py"
    with open(init_path, "r", encoding="utf-8") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    if version_match is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise RuntimeError(f"Unable to find __version__ in {init_path}")
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    return str(public_version)

Tri Dao's avatar
Tri Dao committed
257

258
259
# Exact torch version pinned in install_requires and named in the package
# long description.
PYTORCH_VERSION = "2.4.0"
# CUDA version for which no "cu..." suffix is appended to the package version.
MAIN_CUDA_VERSION = "12.1"
260
261


262
263
def get_nvcc_cuda_version() -> Version:
    """Get the CUDA version from nvcc.

    Runs ``nvcc -V`` from ``CUDA_HOME`` and parses the token that follows the
    word "release" in its output.

    Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py

    Returns:
        The parsed CUDA version.

    Raises:
        RuntimeError: If ``CUDA_HOME`` is not set.
        subprocess.CalledProcessError: If ``nvcc -V`` exits with an error.
        ValueError: If the nvcc output does not contain the word "release".
    """
    # An explicit check is used rather than `assert`, which is stripped
    # when Python runs with -O.
    if CUDA_HOME is None:
        raise RuntimeError("CUDA_HOME is not set")
    nvcc_path = os.path.join(CUDA_HOME, "bin", "nvcc")
    nvcc_output = subprocess.check_output([nvcc_path, "-V"],
                                          universal_newlines=True)
    output = nvcc_output.split()
    release_idx = output.index("release") + 1
    # The release token looks like "12.1,"; drop everything after the comma.
    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
    return nvcc_cuda_version
274
275
276
277


def get_version() -> str:
    """Return the package version, tagged with the CUDA version when it is not the main one.

    When the nvcc CUDA version differs from ``MAIN_CUDA_VERSION``, a
    ``cu<digits>`` tag (for example ``cu118``) is appended as a local version
    segment.

    Returns:
        The full version string.
    """
    version = get_package_version()
    cuda_version = str(get_nvcc_cuda_version())
    if cuda_version != MAIN_CUDA_VERSION:
        cuda_version_str = cuda_version.replace(".", "")[:3]
        # A version may contain only one "+". If the package version already
        # carries a local segment, join the CUDA tag to it with "." instead.
        sep = "." if "+" in version else "+"
        version += f"{sep}cu{cuda_version_str}"
    return version

Woosuk Kwon's avatar
Woosuk Kwon committed
284

285
286
# Register the CMake-built extension module.
ext_modules.append(CMakeExtension(name="vllm_flash_attn.vllm_flash_attn_c"))

# Package metadata and build configuration.
setup(
    name="vllm-flash-attn",
    version=get_version(),
    packages=find_packages(exclude=("build",
                                    "csrc",
                                    "include",
                                    "tests",
                                    "dist",
                                    "docs",
                                    "benchmarks",
                                    f"{PACKAGE_NAME}.egg-info",)),
    author="vLLM Team",
    description="Forward-only flash-attn",
    long_description=f"Forward-only flash-attn package built for PyTorch {PYTORCH_VERSION} and CUDA {MAIN_CUDA_VERSION}",
    url="https://github.com/vllm-project/flash-attention.git",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    # Only install the CMake build command when there is something to build.
    cmdclass={"build_ext": cmake_build_ext} if ext_modules else {},
    python_requires=">=3.8",
    install_requires=[f"torch == {PYTORCH_VERSION}"],
    setup_requires=["psutil"],
)