import functools
import importlib.util
import logging
import os
import re
import subprocess
import sys
from pathlib import Path
from shutil import which
from typing import Dict, List

import torch
from packaging.version import Version, parse
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
from setuptools_scm import get_version
from torch.utils.cpp_extension import CUDA_HOME


def load_module_from_path(module_name, path):
    """Execute the Python file at *path* and return it as a module.

    The new module is registered in ``sys.modules`` under *module_name*
    before its code is executed, then the executed module object is returned.
    """
    module_spec = importlib.util.spec_from_file_location(module_name, path)
    loaded = importlib.util.module_from_spec(module_spec)
    # Register first so code inside the file can look itself up by name.
    sys.modules[module_name] = loaded
    module_spec.loader.exec_module(loaded)
    return loaded


ROOT_DIR = os.path.dirname(__file__)
logger = logging.getLogger(__name__)

# cannot import envs directly because it depends on vllm,
#  which is not installed yet; load the file by path instead.
_ENVS_FILE = os.path.join(ROOT_DIR, 'vllm', 'envs.py')
envs = load_module_from_path('envs', _ENVS_FILE)

VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE

_ON_LINUX = sys.platform.startswith("linux")
if not _ON_LINUX:
    # Non-Linux hosts fall back to the "empty" (device-less) target.
    logger.warning(
        "vLLM only supports Linux platform (including WSL). "
        "Building on %s, "
        "so vLLM may not be able to run correctly", sys.platform)
    VLLM_TARGET_DEVICE = "empty"

MAIN_CUDA_VERSION = "12.1"


def is_sccache_available() -> bool:
    """Return True if an ``sccache`` executable can be found."""
    sccache_path = which("sccache")
    return sccache_path is not None


def is_ccache_available() -> bool:
    """Return True if a ``ccache`` executable can be found."""
    ccache_path = which("ccache")
    return ccache_path is not None


def is_ninja_available() -> bool:
    """Return True if a ``ninja`` executable can be found."""
    ninja_path = which("ninja")
    return ninja_path is not None


def remove_prefix(text, prefix):
    """Return *text* without a leading *prefix*; unchanged if absent."""
    return text[len(prefix):] if text.startswith(prefix) else text


class CMakeExtension(Extension):
    """A setuptools ``Extension`` whose compilation is delegated to CMake.

    Args:
        name: Dotted module name of the extension (e.g. ``"vllm._C"``).
        cmake_lists_dir: Directory holding the ``CMakeLists.txt`` to
            configure; stored on the instance as an absolute path.
        **kwa: Extra keyword arguments forwarded to ``Extension``.
    """

    def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:
        abs_lists_dir = os.path.abspath(cmake_lists_dir)
        # No sources are listed: CMake, not setuptools, compiles the code.
        super().__init__(name, sources=[], py_limited_api=True, **kwa)
        self.cmake_lists_dir = abs_lists_dir


class cmake_build_ext(build_ext):
    """``build_ext`` command that configures, builds and installs via CMake.

    Each distinct CMakeLists directory is configured once, all extension
    targets are then built with a single ``cmake --build`` call, and each
    target's install component is installed where setuptools expects the
    extension module.
    """

    # A dict of extension directories that have been configured.
    # Kept on the class (accessed as `cmake_build_ext.did_config`) so the
    # record is shared across instances.
    did_config: Dict[str, bool] = {}

    #
    # Determine number of compilation jobs and optionally nvcc compile threads.
    #
    def compute_num_jobs(self):
        """Return ``(num_jobs, nvcc_threads)`` for the build.

        ``num_jobs`` is ``MAX_JOBS`` when set, otherwise the number of CPUs
        available to this process (at least 1). When building for CUDA with
        nvcc >= 11.2, ``nvcc_threads`` is ``NVCC_THREADS`` (default 1) and
        ``num_jobs`` is divided by it (never below 1); otherwise
        ``nvcc_threads`` is None.
        """
        # `num_jobs` is either the value of the MAX_JOBS environment variable
        # (if defined) or the number of CPUs available.
        num_jobs = envs.MAX_JOBS
        if num_jobs is not None:
            num_jobs = int(num_jobs)
            logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
        else:
            try:
                # os.sched_getaffinity() isn't universally available, so fall
                #  back to os.cpu_count() if we get an error here.
                num_jobs = len(os.sched_getaffinity(0))
            except AttributeError:
                # os.cpu_count() returns None when the count is unknown;
                # use one job rather than crashing on `None // n` below.
                num_jobs = os.cpu_count() or 1

        nvcc_threads = None
        if _is_cuda() and get_nvcc_cuda_version() >= Version("11.2"):
            # `nvcc_threads` is either the value of the NVCC_THREADS
            # environment variable (if defined) or 1.
            # when it is set, we reduce `num_jobs` to avoid
            # overloading the system.
            nvcc_threads = envs.NVCC_THREADS
            if nvcc_threads is not None:
                nvcc_threads = int(nvcc_threads)
                logger.info(
                    "Using NVCC_THREADS=%d as the number of nvcc threads.",
                    nvcc_threads)
            else:
                nvcc_threads = 1
            # Guard the divisor so NVCC_THREADS=0 cannot raise
            # ZeroDivisionError.
            num_jobs = max(1, num_jobs // max(1, nvcc_threads))

        return num_jobs, nvcc_threads

    #
    # Perform cmake configuration for a single extension.
    #
    def configure(self, ext: CMakeExtension) -> None:
        """Run the CMake configure step for *ext*, once per CMakeLists dir.

        The build type is ``CMAKE_BUILD_TYPE`` if set, else ``Debug`` when
        this command's ``debug`` option is set, else ``RelWithDebInfo``.
        sccache (preferred) or ccache is used as the compiler launcher when
        found, and Ninja with a ``compile`` job pool when available.
        """
        # If we've already configured using the CMakeLists.txt for
        # this extension, exit early.
        if ext.cmake_lists_dir in cmake_build_ext.did_config:
            return

        cmake_build_ext.did_config[ext.cmake_lists_dir] = True

        # Select the build type.
        # Note: optimization level + debug info are set by the build type
        default_cfg = "Debug" if self.debug else "RelWithDebInfo"
        cfg = envs.CMAKE_BUILD_TYPE or default_cfg

        cmake_args = [
            '-DCMAKE_BUILD_TYPE={}'.format(cfg),
            '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE),
        ]

        verbose = envs.VERBOSE
        if verbose:
            cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON']

        if is_sccache_available():
            cmake_args += [
                '-DCMAKE_C_COMPILER_LAUNCHER=sccache',
                '-DCMAKE_CXX_COMPILER_LAUNCHER=sccache',
                '-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache',
                '-DCMAKE_HIP_COMPILER_LAUNCHER=sccache',
            ]
        elif is_ccache_available():
            cmake_args += [
                '-DCMAKE_C_COMPILER_LAUNCHER=ccache',
                '-DCMAKE_CXX_COMPILER_LAUNCHER=ccache',
                '-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache',
                '-DCMAKE_HIP_COMPILER_LAUNCHER=ccache',
            ]

        # Pass the python executable to cmake so it can find an exact
        # match.
        cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)]

        # Pass the python path to cmake so it can reuse the build dependencies
        # on subsequent calls to python.
        cmake_args += ['-DVLLM_PYTHON_PATH={}'.format(":".join(sys.path))]

        # Override the base directory for FetchContent downloads to $ROOT/.deps
        # This allows sharing dependencies between profiles,
        # and plays more nicely with sccache.
        # To override this, set the FETCHCONTENT_BASE_DIR environment variable.
        fc_base_dir = os.path.join(ROOT_DIR, ".deps")
        fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir)
        cmake_args += ['-DFETCHCONTENT_BASE_DIR={}'.format(fc_base_dir)]

        #
        # Setup parallelism and build tool
        #
        num_jobs, nvcc_threads = self.compute_num_jobs()

        if nvcc_threads:
            cmake_args += ['-DNVCC_THREADS={}'.format(nvcc_threads)]

        if is_ninja_available():
            build_tool = ['-G', 'Ninja']
            cmake_args += [
                '-DCMAKE_JOB_POOL_COMPILE:STRING=compile',
                '-DCMAKE_JOB_POOLS:STRING=compile={}'.format(num_jobs),
            ]
        else:
            # Default build tool to whatever cmake picks.
            build_tool = []
        subprocess.check_call(
            ['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
            cwd=self.build_temp)

    def build_extensions(self) -> None:
        """Configure, build and install every extension with CMake.

        Raises:
            RuntimeError: if the ``cmake`` executable cannot be run.
        """
        # Ensure that CMake is present and working
        try:
            subprocess.check_output(['cmake', '--version'])
        except OSError as e:
            raise RuntimeError('Cannot find CMake executable') from e

        # Create build directory if it does not exist.
        os.makedirs(self.build_temp, exist_ok=True)

        targets = []
        # CMake target / install-component names are the extension names
        # with the "vllm." and "vllm_flash_attn." prefixes stripped.
        target_name = lambda s: remove_prefix(remove_prefix(s, "vllm."),
                                              "vllm_flash_attn.")

        # Build all the extensions
        for ext in self.extensions:
            self.configure(ext)
            targets.append(target_name(ext.name))

        num_jobs, _ = self.compute_num_jobs()

        build_args = [
            "--build",
            ".",
            f"-j={num_jobs}",
            *[f"--target={name}" for name in targets],
        ]

        subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)

        # Install the libraries
        build_temp_dir = Path(self.build_temp).absolute()
        for ext in self.extensions:
            # Install the extension into the proper location
            outdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute()

            # Skip if the install directory is the same as the build directory.
            # Compare Path to Path: the old `outdir == self.build_temp`
            # compared a Path with a str and was therefore never true.
            if outdir == build_temp_dir:
                continue

            # CMake appends the extension prefix to the install path,
            # and outdir already contains that prefix, so we need to remove it.
            prefix = outdir
            for _ in range(ext.name.count('.')):
                prefix = prefix.parent

            # prefix here should actually be the same for all components
            install_args = [
                "cmake", "--install", ".", "--prefix", prefix, "--component",
                target_name(ext.name)
            ]
            subprocess.check_call(install_args, cwd=self.build_temp)

    def run(self):
        """Run ``build_ext``, then copy built flash-attn ``.py`` files back."""
        # First, run the standard build_ext command to compile the extensions
        super().run()

        # copy vllm/vllm_flash_attn/*.py from self.build_lib to current
        # directory so that they can be included in the editable build
        import glob
        files = glob.glob(
            os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "*.py"))
        for file in files:
            dst_file = os.path.join("vllm/vllm_flash_attn",
                                    os.path.basename(file))
            print(f"Copying {file} to {dst_file}")
            self.copy_file(file, dst_file)


@functools.lru_cache(maxsize=None)
def _is_hpu() -> bool:
    """Return True if an Intel Gaudi (HPU) target is requested or detected.

    If VLLM_TARGET_DEVICE is "hpu" the answer is True without probing.
    Otherwise:
      1. Run ``hl-smi``; success means an HPU is available.
      2. If that fails, treat the device nodes ``/dev/accel/accel0`` or
         ``/dev/accel/accel_controlD0`` as evidence of an HPU.
      3. As a last resort, count ``habanalabs`` lines in ``lsmod`` output.

    The result is cached: probing spawns subprocesses, and ``_is_cuda()``
    calls this function every time it is evaluated.
    """
    if VLLM_TARGET_DEVICE == "hpu":
        # Explicitly requested; the probes below could not change the answer.
        return True
    is_hpu_available = True
    try:
        subprocess.run(["hl-smi"], capture_output=True, check=True)
    except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
        if not os.path.exists('/dev/accel/accel0') and not os.path.exists(
                '/dev/accel/accel_controlD0'):
            # last resort...
            try:
                # Fixed command string, no external input is interpolated.
                output = subprocess.check_output(
                    'lsmod | grep habanalabs | wc -l', shell=True)
                is_hpu_available = int(output) > 0
            except (ValueError, FileNotFoundError, PermissionError,
                    subprocess.CalledProcessError):
                is_hpu_available = False
    return is_hpu_available


def _no_device() -> bool:
    """Return True for the device-less ("empty") build target."""
    target = VLLM_TARGET_DEVICE
    return target == "empty"


def _is_cuda() -> bool:
    """Return True when building the CUDA variant.

    Requires VLLM_TARGET_DEVICE == "cuda", a torch build with CUDA support,
    and that no Neuron, TPU or HPU target is detected.
    """
    if VLLM_TARGET_DEVICE != "cuda":
        return False
    if torch.version.cuda is None:
        return False
    return not (_is_neuron() or _is_tpu() or _is_hpu())


def _is_hip() -> bool:
    """Return True for a ROCm build.

    The target must be "cuda" or "rocm", and torch must be built with HIP.
    """
    return (VLLM_TARGET_DEVICE in ("cuda", "rocm")
            and torch.version.hip is not None)


@functools.lru_cache(maxsize=None)
def _is_neuron() -> bool:
    """Return True if an AWS Neuron target is requested or detected.

    If VLLM_TARGET_DEVICE is "neuron" the answer is True without probing.
    Otherwise a successful ``neuron-ls`` run counts as detection.

    The result is cached: probing spawns a subprocess, and ``_is_cuda()``
    calls this function every time it is evaluated.
    """
    if VLLM_TARGET_DEVICE == "neuron":
        # Explicitly requested; the probe could not change the answer.
        return True
    try:
        subprocess.run(["neuron-ls"], capture_output=True, check=True)
    except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
        return False
    return True


def _is_tpu() -> bool:
    """Return True when the requested build target is "tpu"."""
    target = VLLM_TARGET_DEVICE
    return target == "tpu"


def _is_cpu() -> bool:
    """Return True when the requested build target is "cpu"."""
    target = VLLM_TARGET_DEVICE
    return target == "cpu"


def _is_openvino() -> bool:
    """Return True when the requested build target is "openvino"."""
    target = VLLM_TARGET_DEVICE
    return target == "openvino"


def _is_xpu() -> bool:
    """Return True when the requested build target is "xpu"."""
    target = VLLM_TARGET_DEVICE
    return target == "xpu"


def _build_custom_ops() -> bool:
    """Return True if the `vllm._C` custom-op extension should be built.

    That is the case for CUDA, ROCm (HIP) and CPU targets; the checks are
    evaluated in that order and stop at the first match.
    """
    return any(check() for check in (_is_cuda, _is_hip, _is_cpu))


def get_hipcc_rocm_version():
    """Return the HIP version reported by ``hipcc --version``, or None.

    None is returned (after printing a message) when ``hipcc`` cannot be
    executed at all, exits with a non-zero status, or prints no
    ``HIP version: <ver>`` line.
    """
    try:
        # Run the hipcc --version command
        result = subprocess.run(['hipcc', '--version'],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                                text=True)
    except OSError:
        # hipcc is missing or not executable: report it like a failed run
        # instead of letting FileNotFoundError escape.
        print("Error running 'hipcc --version'")
        return None

    # Check if the command was executed successfully
    if result.returncode != 0:
        print("Error running 'hipcc --version'")
        return None

    # Extract the version using a regular expression
    match = re.search(r'HIP version: (\S+)', result.stdout)
    if match is None:
        print("Could not find HIP version in the output")
        return None
    return match.group(1)


def get_neuronxcc_version():
    """Return the installed ``neuronxcc`` compiler version string.

    Reads ``neuronxcc/version/__init__.py`` from this interpreter's
    ``purelib`` site-packages directory and extracts its ``__version__``
    assignment (single- or double-quoted).

    Raises:
        FileNotFoundError: if the version file does not exist.
        RuntimeError: if no ``__version__`` assignment is found.
    """
    import sysconfig
    site_dir = sysconfig.get_paths()["purelib"]
    version_file = os.path.join(site_dir, "neuronxcc", "version",
                                "__init__.py")

    # Read the version module's source; it is Python source, so decode it
    # as UTF-8 rather than with the locale's default encoding.
    with open(version_file, encoding="utf-8") as fp:
        content = fp.read()

    # Extract the version; tolerate either quote style and flexible spacing
    # around "=".
    match = re.search(r"""__version__\s*=\s*['"](\S+?)['"]""", content)
    if match:
        return match.group(1)
    raise RuntimeError("Could not find Neuron version in the output")


def get_nvcc_cuda_version() -> Version:
    """Get the CUDA version from nvcc.

    Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py

    Raises:
        RuntimeError: if ``CUDA_HOME`` is not set.
    """
    # Explicit check instead of `assert`, so it still fires under `python -O`.
    if CUDA_HOME is None:
        raise RuntimeError("CUDA_HOME is not set")
    nvcc_output = subprocess.check_output(
        [os.path.join(CUDA_HOME, "bin", "nvcc"), "-V"],
        universal_newlines=True)
    # The version is taken from the token after "release", with anything
    # from the first "," onward dropped (nvcc is assumed to print
    # "... release X.Y, ...").
    output = nvcc_output.split()
    release_idx = output.index("release") + 1
    return parse(output[release_idx].split(",")[0])


def get_path(*filepath) -> str:
    """Join the given path components onto ROOT_DIR."""
    parts = (ROOT_DIR, *filepath)
    return os.path.join(*parts)


def get_gaudi_sw_version():
    """
    Returns the driver version.

    The version is parsed from the third line of ``hl-smi`` output.
    "0.0.0" is returned when ``hl-smi`` cannot be run, fails, prints
    nothing, or prints output in an unexpected format.
    """
    # Enable console printing for `hl-smi` check. Extend the current
    # environment rather than replacing it, so PATH and library paths
    # remain visible to hl-smi.
    try:
        output = subprocess.run(["hl-smi"],
                                text=True,
                                capture_output=True,
                                env={
                                    **os.environ, "ENABLE_CONSOLE": "true"
                                })
    except OSError:
        return "0.0.0"  # when hl-smi is not available
    if output.returncode == 0 and output.stdout:
        try:
            return output.stdout.split("\n")[2].replace(
                " ", "").split(":")[1][:-1].split("-")[0]
        except IndexError:
            # Unexpected output layout: fall back to the default below.
            pass
    return "0.0.0"  # when hl-smi is not available


def get_vllm_version() -> str:
    """Return the package version with a target-device local suffix.

    The base version comes from ``setuptools_scm.get_version`` (which also
    writes ``vllm/_version.py``). A suffix is appended with "+" (or "."
    if the base version already contains "+"):

    * empty target requested via envs: ``empty``
    * CUDA: ``cu<digits>`` (e.g. CUDA 11.8 gives ``cu118``), omitted when
      the nvcc version equals MAIN_CUDA_VERSION or when running ``sdist``
    * ROCm: ``rocm<digits>``; Neuron: ``neuron<digits>``;
      Gaudi: ``gaudi<digits>``
    * ``openvino``, ``tpu``, ``cpu`` or ``xpu`` for those targets

    Raises:
        RuntimeError: if no supported target is detected.
    """
    version = get_version(
        write_to="vllm/_version.py",  # TODO: move this to pyproject.toml
    )

    sep = "+" if "+" not in version else "."  # dev versions might contain +

    if _no_device():
        # Only tag builds where the env explicitly asked for "empty"; the
        # non-Linux fallback rewrites only the module-level constant.
        if envs.VLLM_TARGET_DEVICE == "empty":
            version += f"{sep}empty"
    elif _is_cuda():
        cuda_version = str(get_nvcc_cuda_version())
        if cuda_version != MAIN_CUDA_VERSION:
            cuda_version_str = cuda_version.replace(".", "")[:3]
            # skip this for source tarball, required for pypi
            if "sdist" not in sys.argv:
                version += f"{sep}cu{cuda_version_str}"
    elif _is_hip():
        # Get the HIP version. get_hipcc_rocm_version() returns None when
        # hipcc fails or its output has no version line; fall back to the
        # HIP version torch was built with instead of crashing on None.
        hipcc_version = get_hipcc_rocm_version() or torch.version.hip
        if hipcc_version and hipcc_version != MAIN_CUDA_VERSION:
            rocm_version_str = hipcc_version.replace(".", "")[:3]
            version += f"{sep}rocm{rocm_version_str}"
    elif _is_neuron():
        # Get the Neuron version
        neuron_version = str(get_neuronxcc_version())
        if neuron_version != MAIN_CUDA_VERSION:
            neuron_version_str = neuron_version.replace(".", "")[:3]
            version += f"{sep}neuron{neuron_version_str}"
    elif _is_hpu():
        # Get the Intel Gaudi Software Suite version
        gaudi_sw_version = str(get_gaudi_sw_version())
        if gaudi_sw_version != MAIN_CUDA_VERSION:
            gaudi_sw_version = gaudi_sw_version.replace(".", "")[:3]
            version += f"{sep}gaudi{gaudi_sw_version}"
    elif _is_openvino():
        version += f"{sep}openvino"
    elif _is_tpu():
        version += f"{sep}tpu"
    elif _is_cpu():
        version += f"{sep}cpu"
    elif _is_xpu():
        version += f"{sep}xpu"
    else:
        raise RuntimeError("Unknown runtime environment")

    return version


def read_readme() -> str:
    """Return the contents of README.md, or an empty string if it is absent."""
    readme_path = get_path("README.md")
    if not os.path.isfile(readme_path):
        return ""
    with open(readme_path, encoding="utf-8") as f:
        return f.read()


def get_requirements() -> List[str]:
    """Get Python package dependencies from requirements.txt.

    The requirements file is chosen by the detected target device. CUDA
    builds drop ``vllm-flash-attn`` unless torch was built for CUDA 12.1.

    Raises:
        ValueError: if no supported target device is detected.
    """

    def _read_requirements(filename: str) -> List[str]:
        """Read *filename* (relative to ROOT_DIR), following ``-r`` includes.

        Lines are whitespace-stripped. Blank lines, full-line ``#`` comments
        and options starting with ``--`` are skipped.
        """
        with open(get_path(filename), encoding="utf-8") as f:
            requirements = f.read().strip().split("\n")
        resolved_requirements = []
        for line in requirements:
            # Strip so CRLF files and stray spaces don't leak into specifiers.
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            if line.startswith("-r "):
                resolved_requirements += _read_requirements(line.split()[1])
            elif line.startswith("--"):
                continue
            else:
                resolved_requirements.append(line)
        return resolved_requirements

    if _no_device():
        requirements = _read_requirements("requirements-cuda.txt")
    elif _is_cuda():
        requirements = _read_requirements("requirements-cuda.txt")
        # Only major/minor matter; tolerate versions with extra components.
        cuda_major, cuda_minor = torch.version.cuda.split(".")[:2]
        modified_requirements = []
        for req in requirements:
            if ("vllm-flash-attn" in req
                    and not (cuda_major == "12" and cuda_minor == "1")):
                # vllm-flash-attn is built only for CUDA 12.1.
                # Skip for other versions.
                continue
            modified_requirements.append(req)
        requirements = modified_requirements
    elif _is_hip():
        requirements = _read_requirements("requirements-rocm.txt")
    elif _is_neuron():
        requirements = _read_requirements("requirements-neuron.txt")
    elif _is_hpu():
        requirements = _read_requirements("requirements-hpu.txt")
    elif _is_openvino():
        requirements = _read_requirements("requirements-openvino.txt")
    elif _is_tpu():
        requirements = _read_requirements("requirements-tpu.txt")
    elif _is_cpu():
        requirements = _read_requirements("requirements-cpu.txt")
    elif _is_xpu():
        requirements = _read_requirements("requirements-xpu.txt")
    else:
        raise ValueError(
            "Unsupported platform, please use CUDA, ROCm, Neuron, HPU, "
            "OpenVINO, TPU, XPU, or CPU.")
    return requirements


# Native extensions to build, chosen from the detected target device.
ext_modules = []

if _is_cuda() or _is_hip():
    ext_modules.append(CMakeExtension(name="vllm._moe_C"))

if _is_hip():
    ext_modules.append(CMakeExtension(name="vllm._rocm_C"))

if _is_cuda():
    ext_modules.append(
        CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))

if _build_custom_ops():
    ext_modules.append(CMakeExtension(name="vllm._C"))

package_data = {
    "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
}

if envs.VLLM_USE_PRECOMPILED:
    # Precompiled mode: compile nothing and package any `*.so` files.
    ext_modules = []
    package_data["vllm"].append("*.so")

if _no_device():
    # Device-less build: no native extensions.
    ext_modules = []

# Register the CMake-driven build_ext only when there is something to build.
cmdclass = {"build_ext": cmake_build_ext} if ext_modules else {}

setup(
    name="vllm",
    version=get_vllm_version(),
    author="vLLM Team",
    license="Apache 2.0",
    description=("A high-throughput and memory-efficient inference and "
                 "serving engine for LLMs"),
    long_description=read_readme(),
    long_description_content_type="text/markdown",
    url="https://github.com/vllm-project/vllm",
    project_urls={
        "Homepage": "https://github.com/vllm-project/vllm",
        "Documentation": "https://vllm.readthedocs.io/en/latest/",
    },
    classifiers=[
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "License :: OSI Approved :: Apache Software License",
        "Intended Audience :: Developers",
        "Intended Audience :: Information Technology",
        "Intended Audience :: Science/Research",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Scientific/Engineering :: Information Analysis",
    ],
    packages=find_packages(exclude=("benchmarks", "csrc", "docs", "examples",
                                    "tests*")),
    python_requires=">=3.9",
    install_requires=get_requirements(),
    ext_modules=ext_modules,
    extras_require={
        "tensorizer": ["tensorizer>=2.9.0"],
        "audio": ["librosa", "soundfile"]  # Required for audio processing
    },
    cmdclass=cmdclass,
    package_data=package_data,
    entry_points={
        "console_scripts": [
            "vllm=vllm.scripts:main",
        ],
    },
)