setup.py 26 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ctypes
5
import importlib.util
6
import json
7
import logging
8
import os
9
import re
10
import shutil
11
import subprocess
bnellnm's avatar
bnellnm committed
12
import sys
13
from pathlib import Path
14
from shutil import which
15

Woosuk Kwon's avatar
Woosuk Kwon committed
16
import torch
17
from packaging.version import Version, parse
18
from setuptools import Extension, setup
19
from setuptools.command.build_ext import build_ext
20
from setuptools_scm import get_version
21
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
22

23
24
25
26
27
28
29
30
31

def load_module_from_path(module_name, path):
    """Import the Python source file at ``path`` under the name ``module_name``.

    The new module is registered in ``sys.modules`` *before* its code runs,
    then executed, and finally returned.
    """
    spec = importlib.util.spec_from_file_location(module_name, path)
    loaded = importlib.util.module_from_spec(spec)
    # Register first so the module can be looked up by name while executing.
    sys.modules[module_name] = loaded
    loader = spec.loader
    loader.exec_module(loaded)
    return loaded


logger = logging.getLogger(__name__)
ROOT_DIR = Path(__file__).parent

# `vllm.envs` cannot be imported the normal way because it depends on the
# `vllm` package, which is not installed yet; load envs.py straight from the
# source tree instead.
envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "vllm", "envs.py"))

# Resolve the effective build target, starting from the VLLM_TARGET_DEVICE
# value read through `envs` and adjusting it for the host platform.
VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE

if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu":
    logger.warning("VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS")
    VLLM_TARGET_DEVICE = "cpu"
elif not (sys.platform.startswith("linux") or sys.platform.startswith("darwin")):
    # darwin must stay in this test: macOS with an explicit `cpu` target reaches
    # this branch and must not be turned into an "empty" build.
    logger.warning(
        "vLLM only supports Linux platform (including WSL) and MacOS. "
        "Building on %s, "
        "so vLLM may not be able to run correctly",
        sys.platform,
    )
    VLLM_TARGET_DEVICE = "empty"
elif (
    sys.platform.startswith("linux")
    and torch.version.cuda is None
    and os.getenv("VLLM_TARGET_DEVICE") is None
    and torch.version.hip is None
):
    # if cuda or hip is not available and VLLM_TARGET_DEVICE is not set
    # explicitly in the environment, fallback to cpu
    VLLM_TARGET_DEVICE = "cpu"

def is_sccache_available() -> bool:
    """Return True when ``sccache`` is on PATH and not disabled.

    Setting VLLM_DISABLE_SCCACHE to a non-zero integer disables it.
    """
    if which("sccache") is None:
        return False
    disabled = bool(int(os.getenv("VLLM_DISABLE_SCCACHE", "0")))
    return not disabled


def is_ccache_available() -> bool:
    """Return True if a ``ccache`` executable can be found on PATH."""
    ccache_path = which("ccache")
    return ccache_path is not None


def is_ninja_available() -> bool:
    """Return True if a ``ninja`` executable can be found on PATH."""
    ninja_path = which("ninja")
    return ninja_path is not None


def is_url_available(url: str) -> bool:
    """Return True if opening ``url`` succeeds with HTTP status 200.

    Any error while opening the URL is treated as "not available".
    """
    from urllib.request import urlopen

    try:
        with urlopen(url) as response:
            code = response.status
    except Exception:
        return False
    return code == 200


class CMakeExtension(Extension):
    """A setuptools ``Extension`` whose artifacts are produced by CMake.

    It declares no Python-level sources and is marked ``py_limited_api=True``.
    It records the absolute path of the directory holding the
    ``CMakeLists.txt`` to configure.
    """

    def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None:
        super().__init__(name, py_limited_api=True, sources=[], **kwa)
        # Resolved against the current working directory at construction time.
        self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)


class cmake_build_ext(build_ext):
    """``build_ext`` command that drives CMake for ``CMakeExtension`` targets.

    Each distinct ``cmake_lists_dir`` is configured once. All extension
    targets are then built with a single ``cmake --build`` call, and each
    target is installed as its own CMake install component.
    """

    # A dict of extension directories that have been configured. It lives on
    # the class and is accessed as ``cmake_build_ext.did_config`` so the early
    # return in configure() sees every directory configured so far.
    did_config: dict[str, bool] = {}

    #
    # Determine number of compilation jobs and optionally nvcc compile threads.
    #
    def compute_num_jobs(self):
        """Return ``(num_jobs, nvcc_threads)`` for the CMake build.

        ``num_jobs`` is MAX_JOBS when set. Otherwise it is the number of CPUs
        usable by this process (at least 1).

        For CUDA builds with nvcc 11.2+, ``nvcc_threads`` is NVCC_THREADS
        (default 1). In that case ``num_jobs`` is divided by it, never dropping
        below 1. Otherwise ``nvcc_threads`` is None.
        """
        # `num_jobs` is either the value of the MAX_JOBS environment variable
        # (if defined) or the number of CPUs available.
        num_jobs = envs.MAX_JOBS
        if num_jobs is not None:
            num_jobs = int(num_jobs)
            logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
        else:
            try:
                # os.sched_getaffinity() isn't universally available, so fall
                #  back to os.cpu_count() if we get an error here.
                num_jobs = len(os.sched_getaffinity(0))
            except AttributeError:
                # os.cpu_count() returns None when the count is undetermined;
                # use a single job then instead of crashing on `//` below.
                num_jobs = os.cpu_count() or 1

        nvcc_threads = None
        if _is_cuda() and get_nvcc_cuda_version() >= Version("11.2"):
            # `nvcc_threads` is either the value of the NVCC_THREADS
            # environment variable (if defined) or 1.
            # when it is set, we reduce `num_jobs` to avoid
            # overloading the system.
            nvcc_threads = envs.NVCC_THREADS
            if nvcc_threads is not None:
                nvcc_threads = int(nvcc_threads)
                logger.info(
                    "Using NVCC_THREADS=%d as the number of nvcc threads.", nvcc_threads
                )
            else:
                nvcc_threads = 1
            num_jobs = max(1, num_jobs // nvcc_threads)

        return num_jobs, nvcc_threads

    #
    # Perform cmake configuration for a single extension.
    #
    def configure(self, ext: CMakeExtension) -> None:
        """Run ``cmake`` configuration for ``ext`` into ``self.build_temp``.

        This is a no-op if ``ext.cmake_lists_dir`` was already configured.
        """
        # If we've already configured using the CMakeLists.txt for
        # this extension, exit early.
        if ext.cmake_lists_dir in cmake_build_ext.did_config:
            return

        cmake_build_ext.did_config[ext.cmake_lists_dir] = True

        # Select the build type.
        # Note: optimization level + debug info are set by the build type
        default_cfg = "Debug" if self.debug else "RelWithDebInfo"
        cfg = envs.CMAKE_BUILD_TYPE or default_cfg

        cmake_args = [
            "-DCMAKE_BUILD_TYPE={}".format(cfg),
            "-DVLLM_TARGET_DEVICE={}".format(VLLM_TARGET_DEVICE),
        ]

        verbose = envs.VERBOSE
        if verbose:
            cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"]

        # Prefer sccache over ccache as the compiler launcher when available.
        if is_sccache_available():
            cmake_args += [
                "-DCMAKE_C_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
            ]
        elif is_ccache_available():
            cmake_args += [
                "-DCMAKE_C_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
            ]

        # Pass the python executable to cmake so it can find an exact
        # match.
        cmake_args += ["-DVLLM_PYTHON_EXECUTABLE={}".format(sys.executable)]

        # Pass the python path to cmake so it can reuse the build dependencies
        # on subsequent calls to python.
        cmake_args += ["-DVLLM_PYTHON_PATH={}".format(":".join(sys.path))]

        # Override the base directory for FetchContent downloads to $ROOT/.deps
        # This allows sharing dependencies between profiles,
        # and plays more nicely with sccache.
        # To override this, set the FETCHCONTENT_BASE_DIR environment variable.
        fc_base_dir = os.path.join(ROOT_DIR, ".deps")
        fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir)
        cmake_args += ["-DFETCHCONTENT_BASE_DIR={}".format(fc_base_dir)]

        #
        # Setup parallelism and build tool
        #
        num_jobs, nvcc_threads = self.compute_num_jobs()

        if nvcc_threads:
            cmake_args += ["-DNVCC_THREADS={}".format(nvcc_threads)]

        if is_ninja_available():
            build_tool = ["-G", "Ninja"]
            cmake_args += [
                "-DCMAKE_JOB_POOL_COMPILE:STRING=compile",
                "-DCMAKE_JOB_POOLS:STRING=compile={}".format(num_jobs),
            ]
        else:
            # Default build tool to whatever cmake picks.
            build_tool = []

        # Make sure we use the nvcc from CUDA_HOME
        if _is_cuda():
            cmake_args += [f"-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc"]
        elif _is_hip():
            cmake_args += [f"-DROCM_PATH={ROCM_HOME}"]

        # Extra user-supplied arguments, split on whitespace.
        other_cmake_args = os.environ.get("CMAKE_ARGS")
        if other_cmake_args:
            cmake_args += other_cmake_args.split()

        subprocess.check_call(
            ["cmake", ext.cmake_lists_dir, *build_tool, *cmake_args],
            cwd=self.build_temp,
        )

    def build_extensions(self) -> None:
        """Configure, build, and install every extension with CMake.

        Raises:
            RuntimeError: if the ``cmake`` executable cannot be run.
        """
        # Ensure that CMake is present and working
        try:
            subprocess.check_output(["cmake", "--version"])
        except OSError as e:
            raise RuntimeError("Cannot find CMake executable") from e

        # Create build directory if it does not exist.
        os.makedirs(self.build_temp, exist_ok=True)

        targets = []

        def target_name(s: str) -> str:
            return s.removeprefix("vllm.").removeprefix("vllm_flash_attn.")

        # Build all the extensions
        for ext in self.extensions:
            self.configure(ext)
            targets.append(target_name(ext.name))

        num_jobs, _ = self.compute_num_jobs()

        build_args = [
            "--build",
            ".",
            f"-j={num_jobs}",
            *[f"--target={name}" for name in targets],
        ]

        subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)

        # Install the libraries
        # self.build_temp is a str; a Path never compares equal to a str, so
        # normalise it once for the skip check below.
        build_temp_dir = Path(self.build_temp).absolute()
        for ext in self.extensions:
            # Install the extension into the proper location
            outdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute()

            # Skip if the install directory is the same as the build directory
            if outdir == build_temp_dir:
                continue

            # CMake appends the extension prefix to the install path,
            # and outdir already contains that prefix, so we need to remove it.
            prefix = outdir
            for _ in range(ext.name.count(".")):
                prefix = prefix.parent

            # prefix here should actually be the same for all components
            install_args = [
                "cmake",
                "--install",
                ".",
                "--prefix",
                prefix,
                "--component",
                target_name(ext.name),
            ]
            subprocess.check_call(install_args, cwd=self.build_temp)

    def run(self):
        """Run the standard build, then mirror generated Python files locally.

        The mirroring copies those files into the working directory so that
        editable installs can find them.
        """
        # First, run the standard build_ext command to compile the extensions
        super().run()

        # copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
        # directory so that they can be included in the editable build
        import glob

        files = glob.glob(
            os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "**", "*.py"),
            recursive=True,
        )
        for file in files:
            dst_file = os.path.join(
                "vllm/vllm_flash_attn", file.split("vllm/vllm_flash_attn/")[-1]
            )
            print(f"Copying {file} to {dst_file}")
            os.makedirs(os.path.dirname(dst_file), exist_ok=True)
            self.copy_file(file, dst_file)

        if _is_cuda() or _is_hip():
            # copy vllm/third_party/triton_kernels/**/*.py from self.build_lib
            # to current directory so that they can be included in the editable
            # build
            print(
                f"Copying {self.build_lib}/vllm/third_party/triton_kernels "
                "to vllm/third_party/triton_kernels"
            )
            shutil.copytree(
                f"{self.build_lib}/vllm/third_party/triton_kernels",
                "vllm/third_party/triton_kernels",
                dirs_exist_ok=True,
            )

class precompiled_build_ext(build_ext):
    """``build_ext`` replacement used when VLLM_USE_PRECOMPILED is set.

    Nothing is compiled here. ``run`` only checks that this is a CUDA build
    and does not invoke the base-class build.
    """

    def run(self) -> None:
        # Intentionally no super().run(): the binaries are taken from a
        # prebuilt wheel instead of being compiled.
        assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"

    def build_extensions(self) -> None:
        print("Skipping build_ext: using precompiled extensions.")


class precompiled_wheel_utils:
    """Extracts libraries and other files from an existing wheel."""

    @staticmethod
    def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict:
        """Extract prebuilt vLLM artifacts from a wheel into the working dir.

        Args:
            wheel_url_or_path: Path to a local wheel file. Anything that is
                not an existing file is treated as a URL. It is downloaded into
                a temporary directory, which is removed afterwards.

        Returns:
            A dict mapping dotted package names (e.g. ``"vllm"``) to the list
            of extracted file basenames in that package, for merging into
            ``package_data``.
        """
        import tempfile
        import zipfile

        temp_dir = None
        try:
            if not os.path.isfile(wheel_url_or_path):
                wheel_filename = wheel_url_or_path.split("/")[-1]
                temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
                wheel_path = os.path.join(temp_dir, wheel_filename)
                print(f"Downloading wheel from {wheel_url_or_path} to {wheel_path}")
                from urllib.request import urlretrieve

                urlretrieve(wheel_url_or_path, filename=wheel_path)
            else:
                wheel_path = wheel_url_or_path
                print(f"Using existing wheel at {wheel_path}")

            package_data_patch = {}

            with zipfile.ZipFile(wheel_path) as wheel:
                # Compiled libraries copied verbatim when present in the wheel.
                files_to_copy = {
                    "vllm/_C.abi3.so",
                    "vllm/_moe_C.abi3.so",
                    "vllm/_flashmla_C.abi3.so",
                    "vllm/_flashmla_extension_C.abi3.so",
                    "vllm/_sparse_flashmla_C.abi3.so",
                    "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
                    "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
                    "vllm/cumem_allocator.abi3.so",
                }

                # Python sources under vllm/vllm_flash_attn/. No path segment
                # may start with "." (this rules out ".." and hidden files).
                # fullmatch() requires the *entire* member name to end in
                # ".py": a prefix match() would also accept "x.pyc" or
                # "x.py/../../elsewhere", letting a wheel write outside the
                # intended directory.
                compiled_regex = re.compile(
                    r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
                )
                file_members = [
                    member
                    for member in wheel.filelist
                    if member.filename in files_to_copy
                ]
                file_members += [
                    member
                    for member in wheel.filelist
                    if compiled_regex.fullmatch(member.filename)
                ]

                for file in file_members:
                    print(f"[extract] {file.filename}")
                    target_path = os.path.join(".", file.filename)
                    os.makedirs(os.path.dirname(target_path), exist_ok=True)
                    with (
                        wheel.open(file.filename) as src,
                        open(target_path, "wb") as dst,
                    ):
                        shutil.copyfileobj(src, dst)

                    pkg = os.path.dirname(file.filename).replace("/", ".")
                    package_data_patch.setdefault(pkg, []).append(
                        os.path.basename(file.filename)
                    )

            return package_data_patch
        finally:
            if temp_dir is not None:
                print(f"Removing temporary directory {temp_dir}")
                shutil.rmtree(temp_dir)

    @staticmethod
    def get_base_commit_in_main_branch() -> str:
        """Return the merge-base of the current branch with upstream ``main``.

        The string ``"nightly"`` is returned instead when
        VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL is set, or when any step fails
        with a non-ValueError exception. In a Docker build context the upstream
        ``main`` HEAD commit is returned without consulting the local repo.
        """
        # Force to use the nightly wheel. This is mainly used for CI testing.
        if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL:
            return "nightly"

        try:
            # Get the latest commit hash of the upstream main branch.
            resp_json = subprocess.check_output(
                [
                    "curl",
                    "-s",
                    "https://api.github.com/repos/vllm-project/vllm/commits/main",
                ]
            ).decode("utf-8")
            upstream_main_commit = json.loads(resp_json)["sha"]

            # In Docker build context, .git may be immutable or missing.
            if envs.VLLM_DOCKER_BUILD_CONTEXT:
                return upstream_main_commit

            # Check if the upstream_main_commit exists in the local repo
            try:
                subprocess.check_output(
                    ["git", "cat-file", "-e", f"{upstream_main_commit}"]
                )
            except subprocess.CalledProcessError:
                # If not present, fetch it from the remote repository.
                # Note that this does not update any local branches,
                # but ensures that this commit ref and its history are
                # available in our local repo.
                subprocess.check_call(
                    ["git", "fetch", "https://github.com/vllm-project/vllm", "main"]
                )

            # Then get the commit hash of the current branch that is the same as
            # the upstream main commit.
            current_branch = (
                subprocess.check_output(["git", "branch", "--show-current"])
                .decode("utf-8")
                .strip()
            )

            base_commit = (
                subprocess.check_output(
                    ["git", "merge-base", f"{upstream_main_commit}", current_branch]
                )
                .decode("utf-8")
                .strip()
            )
            return base_commit

        except ValueError as err:
            # NOTE(review): json.JSONDecodeError subclasses ValueError, so an
            # unparseable API response aborts here instead of falling back to
            # nightly — confirm that is intended.
            raise ValueError(err) from None
        except Exception as err:
            logger.warning(
                "Failed to get the base commit in the main branch. "
                "Using the nightly wheel. The libraries in this "
                "wheel may not be compatible with your dev branch: %s",
                err,
            )
            return "nightly"


458
459
460
461
def _no_device() -> bool:
    return VLLM_TARGET_DEVICE == "empty"


462
def _is_cuda() -> bool:
463
    has_cuda = torch.version.cuda is not None
464
    return VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu()
465
466


467
def _is_hip() -> bool:
468
469
470
    return (
        VLLM_TARGET_DEVICE == "cuda" or VLLM_TARGET_DEVICE == "rocm"
    ) and torch.version.hip is not None
471
472


473
474
475
476
def _is_tpu() -> bool:
    return VLLM_TARGET_DEVICE == "tpu"


477
478
479
480
def _is_cpu() -> bool:
    return VLLM_TARGET_DEVICE == "cpu"


481
482
483
484
def _is_xpu() -> bool:
    return VLLM_TARGET_DEVICE == "xpu"


def _build_custom_ops() -> bool:
    """Whether the ``vllm._C`` extension is built (CUDA, ROCm or CPU targets)."""
    # any() over a generator keeps the original short-circuit order.
    return any(check() for check in (_is_cuda, _is_hip, _is_cpu))


def get_rocm_version():
    """Return the ROCm version as ``"major.minor.patch"``, or None.

    The version is obtained by calling ``getROCmVersion`` from
    ``$ROCM_HOME/lib/librocm-core.so``; see
    https://github.com/ROCm/rocm-core/blob/d11f5c20d500f729c393680a01fa902ebf92094b/rocm_version.cpp#L21

    Returns None if any of the following holds:
    - ROCM_HOME is unset.
    - The library file is missing.
    - The call returns a non-zero error code.
    - Loading or calling the library raises.
    """
    if ROCM_HOME is None:
        return None
    try:
        librocm_core_file = Path(ROCM_HOME) / "lib" / "librocm-core.so"
        if not librocm_core_file.is_file():
            return None
        # Pass a str: ctypes.CDLL only documents path-like `name` support
        # from Python 3.12 onward.
        librocm_core = ctypes.CDLL(str(librocm_core_file))
        VerErrors = ctypes.c_uint32
        get_rocm_core_version = librocm_core.getROCmVersion
        get_rocm_core_version.restype = VerErrors
        get_rocm_core_version.argtypes = [
            ctypes.POINTER(ctypes.c_uint32),
            ctypes.POINTER(ctypes.c_uint32),
            ctypes.POINTER(ctypes.c_uint32),
        ]
        major = ctypes.c_uint32()
        minor = ctypes.c_uint32()
        patch = ctypes.c_uint32()

        # A return value of 0 signals success; the three out-params are filled.
        if (
            get_rocm_core_version(
                ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch)
            )
            == 0
        ):
            return f"{major.value}.{minor.value}.{patch.value}"
        return None
    except Exception:
        # Best effort: callers fall back to other version sources on None.
        return None

def get_nvcc_cuda_version() -> Version:
    """Get the CUDA version reported by ``$CUDA_HOME/bin/nvcc -V``.

    The token following ``release`` in nvcc's output (minus any trailing
    ``,...``) is parsed into a ``Version``.

    Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
    """
    assert CUDA_HOME is not None, "CUDA_HOME is not set"
    nvcc_path = CUDA_HOME + "/bin/nvcc"
    raw = subprocess.check_output([nvcc_path, "-V"], universal_newlines=True)
    tokens = raw.split()
    version_token = tokens[tokens.index("release") + 1]
    return parse(version_token.split(",")[0])


def get_gaudi_sw_version():
    """
    Returns the driver version.

    Runs ``hl-smi`` through the shell and parses the third line of its
    output. It takes the text after the first ":" (spaces removed, final
    character dropped) up to the first "-". Returns "0.0.0" when hl-smi exits
    non-zero or prints nothing.
    """
    # Enable console printing for `hl-smi` check. Extend the current
    # environment rather than replacing it: a bare {"ENABLE_CONSOLE": ...}
    # would drop PATH, so the shell could fail to locate hl-smi.
    output = subprocess.run(
        "hl-smi",
        shell=True,
        text=True,
        capture_output=True,
        env={**os.environ, "ENABLE_CONSOLE": "true"},
    )
    if output.returncode == 0 and output.stdout:
        return (
            output.stdout.split("\n")[2]
            .replace(" ", "")
            .split(":")[1][:-1]
            .split("-")[0]
        )
    return "0.0.0"  # when hl-smi is not available


def get_vllm_version() -> str:
    """Compute the package version string for the current build target.

    The base version comes from setuptools_scm, which also writes
    ``vllm/_version.py``. If VLLM_VERSION_OVERRIDE is set, that value is used
    as-is. Otherwise a target label is appended after "+", or after "." if the
    version already has a "+" part. Possible labels are "empty",
    "precompiled", "cu<ver>", "rocm<ver>", "tpu", "cpu" and "xpu".

    Raises:
        RuntimeError: if no known build target is selected.
    """
    # Allow overriding the version. This is useful to build platform-specific
    # wheels (e.g. CPU, TPU) without modifying the source.
    override = os.getenv("VLLM_VERSION_OVERRIDE")
    if override:
        print(f"Overriding VLLM version with {override} from VLLM_VERSION_OVERRIDE")
        os.environ["SETUPTOOLS_SCM_PRETEND_VERSION"] = override
        return get_version(write_to="vllm/_version.py")

    version = get_version(write_to="vllm/_version.py")
    # dev versions might contain +, in which case extend it with "."
    sep = "." if "+" in version else "+"

    label = None
    if _no_device():
        if envs.VLLM_TARGET_DEVICE == "empty":
            label = "empty"
    elif _is_cuda():
        if envs.VLLM_USE_PRECOMPILED:
            label = "precompiled"
        else:
            cuda_version = str(get_nvcc_cuda_version())
            # skip this for source tarball, required for pypi
            if (
                cuda_version != envs.VLLM_MAIN_CUDA_VERSION
                and "sdist" not in sys.argv
            ):
                label = "cu" + cuda_version.replace(".", "")[:3]
    elif _is_hip():
        # Get the Rocm Version
        rocm_version = get_rocm_version() or torch.version.hip
        if rocm_version and rocm_version != envs.VLLM_MAIN_CUDA_VERSION:
            label = "rocm" + rocm_version.replace(".", "")[:3]
    elif _is_tpu():
        label = "tpu"
    elif _is_cpu():
        if envs.VLLM_TARGET_DEVICE == "cpu":
            label = "cpu"
    elif _is_xpu():
        label = "xpu"
    else:
        raise RuntimeError("Unknown runtime environment")

    if label is None:
        return version
    return f"{version}{sep}{label}"


def get_requirements() -> list[str]:
    """Get Python package dependencies from the ``requirements/`` directory.

    The requirements file is chosen by build target: common.txt, cuda.txt,
    rocm.txt, tpu.txt, cpu.txt or xpu.txt. ``-r <file>`` includes are inlined
    recursively, and blank lines, ``#`` comments and ``--option`` lines are
    dropped.

    Raises:
        ValueError: if the build target is not supported.
    """
    requirements_dir = ROOT_DIR / "requirements"

    def _read_requirements(filename: str) -> list[str]:
        # Explicit encoding so parsing does not depend on the build locale.
        with open(requirements_dir / filename, encoding="utf-8") as f:
            requirements = f.read().strip().split("\n")
        resolved_requirements = []
        for line in requirements:
            if line.startswith("-r "):
                # Included files are resolved relative to requirements/.
                resolved_requirements += _read_requirements(line.split()[1])
            elif (
                not line.startswith("--")
                and not line.startswith("#")
                and line.strip() != ""
            ):
                resolved_requirements.append(line)
        return resolved_requirements

    if _no_device():
        requirements = _read_requirements("common.txt")
    elif _is_cuda():
        requirements = _read_requirements("cuda.txt")
        # Only the major version is needed; unpacking exactly two parts would
        # raise on a version string such as "12.4.1".
        cuda_major = torch.version.cuda.split(".")[0]
        if cuda_major != "12":
            # vllm-flash-attn is built only for CUDA 12.x.
            # Skip for other versions.
            requirements = [
                req for req in requirements if "vllm-flash-attn" not in req
            ]
    elif _is_hip():
        requirements = _read_requirements("rocm.txt")
    elif _is_tpu():
        requirements = _read_requirements("tpu.txt")
    elif _is_cpu():
        requirements = _read_requirements("cpu.txt")
    elif _is_xpu():
        requirements = _read_requirements("xpu.txt")
    else:
        raise ValueError("Unsupported platform, please use CUDA, ROCm, or CPU.")
    return requirements


# Extensions to build, in build/install order.
ext_modules = []

if _is_cuda() or _is_hip():
    ext_modules += [
        CMakeExtension(name="vllm._moe_C"),
        CMakeExtension(name="vllm.cumem_allocator"),
        # Optional since this doesn't get built (produce an .so file). This is
        # just copying the relevant .py files from the source repository.
        CMakeExtension(name="vllm.triton_kernels", optional=True),
    ]

if _is_hip():
    ext_modules.append(CMakeExtension(name="vllm._rocm_C"))

if _is_cuda():
    ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
    # FA3 requires CUDA 12.3 or later
    if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
        ext_modules += [
            CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"),
            # Optional since these don't get built (produce an .so file) when
            # not targeting a hopper system
            CMakeExtension(name="vllm._flashmla_C", optional=True),
            CMakeExtension(name="vllm._flashmla_extension_C", optional=True),
        ]

if _build_custom_ops():
    ext_modules.append(CMakeExtension(name="vllm._C"))

# Non-Python files shipped with the `vllm` package.
package_data = {
    "vllm": [
        "py.typed",
        "model_executor/layers/fused_moe/configs/*.json",
        "model_executor/layers/quantization/utils/configs/*.json",
    ]
}

# If using precompiled, extract and patch package_data (in advance of setup)
if envs.VLLM_USE_PRECOMPILED:
    assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
    wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
    if wheel_location is None:
        import platform

        # Host CPU architecture -> platform tag of the published wheels.
        precompiled_wheel_tags = {
            "x86_64": "manylinux1_x86_64",
            "aarch64": "manylinux2014_aarch64",
        }
        arch = platform.machine()
        if arch not in precompiled_wheel_tags:
            raise ValueError(f"Unsupported architecture: {arch}")
        wheel_tag = precompiled_wheel_tags[arch]

        base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
        wheel_name = f"vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
        wheel_url = f"https://wheels.vllm.ai/{base_commit}/{wheel_name}"
        nightly_wheel_url = f"https://wheels.vllm.ai/nightly/{wheel_name}"
        from urllib.request import urlopen

        # Prefer the wheel for the computed base commit; fall back to the
        # nightly wheel when it cannot be opened or is not HTTP 200.
        try:
            with urlopen(wheel_url) as resp:
                if resp.status != 200:
                    wheel_url = nightly_wheel_url
        except Exception as e:
            print(f"[warn] Falling back to nightly wheel: {e}")
            wheel_url = nightly_wheel_url
    else:
        wheel_url = wheel_location

    patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(wheel_url)
    for pkg, files in patch.items():
        package_data.setdefault(pkg, []).extend(files)

# A no-device ("empty") build ships no compiled extensions.
if _no_device():
    ext_modules = []

# Only override build_ext when there is something to build.
if ext_modules:
    cmdclass = {
        "build_ext": (
            precompiled_build_ext if envs.VLLM_USE_PRECOMPILED else cmake_build_ext
        )
    }
else:
    cmdclass = {}

setup(
    # static metadata should rather go in pyproject.toml
    version=get_vllm_version(),
    ext_modules=ext_modules,
    install_requires=get_requirements(),
    extras_require={
        "bench": ["pandas", "matplotlib", "seaborn", "datasets"],
        "tensorizer": ["tensorizer==2.10.1"],
        "fastsafetensors": ["fastsafetensors >= 0.1.10"],
        "runai": ["runai-model-streamer[s3,gcs] >= 0.15.0"],
        # Required for audio processing
        "audio": ["librosa", "soundfile", "mistral_common[audio]"],
        # Kept for backwards compatibility
        "video": [],
        # Kept for backwards compatibility
        "flashinfer": [],
        # Optional deps for AMD FP4 quantization support
        "petit-kernel": ["petit-kernel"],
    },
    cmdclass=cmdclass,
    package_data=package_data,
)