setup.py 17.6 KB
Newer Older
1
import contextlib
2
3
4
import io
import os
import re
5
import shutil
6
import subprocess
7
import warnings
8
9
from pathlib import Path
from typing import List, Set
10

11
from packaging.version import parse, Version
Woosuk Kwon's avatar
Woosuk Kwon committed
12
import setuptools
13
import sys
Woosuk Kwon's avatar
Woosuk Kwon committed
14
import torch
15
import torch.utils.cpp_extension as torch_cpp_ext
16
17
18
19
20
21
from torch.utils.cpp_extension import (
    BuildExtension,
    CUDAExtension,
    CUDA_HOME,
    ROCM_HOME,
)
22
23

# Directory containing this setup script; get_path() joins repository-relative
# file names (README, requirements, vllm/__init__.py) onto it.
ROOT_DIR = os.path.dirname(__file__)
24

25
26
27
28
# vLLM only supports Linux platform.
# Raise explicitly rather than `assert`: assertions are stripped when Python
# runs with -O, which would silently skip this check.
if not sys.platform.startswith("linux"):
    raise RuntimeError(
        "vLLM only supports Linux platform (including WSL).")

29
30
31
32
33
# If you are developing the C++ backend of vLLM, consider building vLLM with
# `python setup.py develop` since it will give you incremental builds.
# The downside is that this method is deprecated, see
# https://github.com/pypa/setuptools/issues/917

34
35
# Reference toolkit version: get_vllm_version() appends a local version
# suffix only when the detected toolkit version differs from this string.
MAIN_CUDA_VERSION = "12.1"

# Supported NVIDIA GPU architectures.
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
# Supported ROCm GPU architectures.
ROCM_SUPPORTED_ARCHS = {"gfx908", "gfx90a", "gfx942", "gfx1100"}


42
43
44
45
def _is_cuda() -> bool:
    return torch.version.cuda is not None


46
47
48
49
def _is_hip() -> bool:
    return torch.version.hip is not None


50
51
52
53
def _is_neuron() -> bool:
    torch_neuronx_installed = True
    try:
        subprocess.run(["neuron-ls"], capture_output=True, check=True)
54
    except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
55
56
57
58
        torch_neuronx_installed = False
    return torch_neuronx_installed


59
# Compiler flags.
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
# TODO(woosuk): Should we use -O3?
NVCC_FLAGS = ["-O2", "-std=c++17"]

if _is_hip():
    if ROCM_HOME is None:
        raise RuntimeError("Cannot find ROCM_HOME. "
                           "ROCm must be available to build the package.")
    # Define USE_ROCM and undefine the __HIP_NO_HALF_* macros for device code.
    NVCC_FLAGS += [
        "-DUSE_ROCM",
        "-U__HIP_NO_HALF_CONVERSIONS__",
        "-U__HIP_NO_HALF_OPERATORS__",
    ]

if _is_cuda() and CUDA_HOME is None:
    raise RuntimeError(
        "Cannot find CUDA_HOME. CUDA must be available to build the package.")

# Mirror PyTorch's _GLIBCXX_USE_CXX11_ABI setting in both host and device
# compiler flags.
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
79

80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100

def get_hipcc_rocm_version():
    """Return the HIP version reported by `hipcc --version`.

    Returns:
        The token following "HIP version: " in the command output, or None
        when hipcc cannot be launched, exits with a non-zero status, or
        prints no recognizable version line. A message is printed in each
        failure case.
    """
    try:
        # Merge stderr into stdout so the version line is found either way.
        result = subprocess.run(['hipcc', '--version'],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                                text=True)
    except OSError:
        # hipcc is missing or not executable: report it like the other
        # failure modes instead of leaking a raw OSError to the caller.
        print("Error running 'hipcc --version'")
        return None

    # Check if the command was executed successfully
    if result.returncode != 0:
        print("Error running 'hipcc --version'")
        return None

    # Extract the version using a regular expression
    match = re.search(r'HIP version: (\S+)', result.stdout)
    if match is None:
        print("Could not find HIP version in the output")
        return None
    return match.group(1)
Woosuk Kwon's avatar
Woosuk Kwon committed
101

102

103
104
105
106
107
def glob(pattern: str):
    """Return the paths matching `pattern` as a sorted list of strings.

    The pattern is resolved against the current working directory and the
    returned paths stay relative to it. Results are sorted so the list of
    extension sources does not depend on filesystem enumeration order.

    Args:
        pattern: A glob pattern such as "csrc/moe/*.cu".

    Returns:
        Matching paths as strings, in ascending order; empty if none match.
    """
    # NOTE(review): presumably setup.py is always run from the repository
    # root, so working-directory-relative paths are what is wanted — confirm.
    root = Path(".")
    return sorted(str(p) for p in root.glob(pattern))


108
109
110
def get_neuronxcc_version():
    """Return the version string of the installed `neuronxcc` package.

    The version is parsed from ``neuronxcc/version/__init__.py`` inside the
    interpreter's ``purelib`` directory rather than by importing the package.

    Returns:
        The value assigned to ``__version__`` in that file.

    Raises:
        OSError: If the version file cannot be opened.
        RuntimeError: If the file has no ``__version__ = '...'`` assignment.
    """
    import sysconfig
    site_dir = sysconfig.get_paths()["purelib"]
    version_file = os.path.join(site_dir, "neuronxcc", "version",
                                "__init__.py")

    # Read the version module as text.
    with open(version_file, "rt") as fp:
        content = fp.read()

    # Extract the version using a regular expression
    match = re.search(r"__version__ = '(\S+)'", content)
    if match is None:
        raise RuntimeError(
            f"Could not find neuronxcc version in {version_file}")
    return match.group(1)


127
128
129
130
131
132
133
134
135
136
137
138
139
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
    """Run `nvcc -V` from `cuda_dir` and return the CUDA release it reports.

    Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
    """
    nvcc_path = cuda_dir + "/bin/nvcc"
    tokens = subprocess.check_output([nvcc_path, "-V"],
                                     universal_newlines=True).split()
    # The version is the token right after "release"; anything from the
    # first comma onwards is dropped before parsing.
    release_token = tokens[tokens.index("release") + 1]
    return parse(release_token.split(",")[0])


140
141
142
143
144
145
146
147
148
149
150
151
def get_pytorch_rocm_arch() -> Set[str]:
    """Get the intersection of the requested and vLLM-supported gfx arches.

    ROCm can get the requested gfx architectures in one of two ways:
    either through the PYTORCH_ROCM_ARCH env var, or from the output of
    rocm_agent_enumerator.

    In either case the resulting list is cross-referenced with vLLM's own
    ROCM_SUPPORTED_ARCHS.

    Returns:
        The requested architectures that are also in ROCM_SUPPORTED_ARCHS.

    Raises:
        RuntimeError: If none of the requested architectures is supported.
    """
    env_arch_list = os.environ.get("PYTORCH_ROCM_ARCH", None)

    # If we don't have PYTORCH_ROCM_ARCH specified pull the list from
    # rocm_agent_enumerator
    if env_arch_list is None:
        command = "rocm_agent_enumerator"
        env_arch_list = (subprocess.check_output(
            [command]).decode('utf-8').strip().replace("\n", ";"))
        arch_source_str = "rocm_agent_enumerator"
    else:
        arch_source_str = "PYTORCH_ROCM_ARCH env variable"

    # Entries are separated by ";" or space. Empty tokens (from repeated or
    # trailing separators) are dropped so they are not reported as
    # unsupported architectures below.
    pytorch_rocm_arch = {
        arch
        for arch in env_arch_list.replace(" ", ";").split(";") if arch
    }

    # Keep only the architectures vLLM supports.
    arch_list = pytorch_rocm_arch.intersection(ROCM_SUPPORTED_ARCHS)

    # If none of the specified architectures are valid, raise an error.
    if not arch_list:
        raise RuntimeError(
            f"None of the ROCM architectures in {arch_source_str} "
            f"({env_arch_list}) is supported. "
            f"Supported ROCM architectures are: {ROCM_SUPPORTED_ARCHS}.")

    # Warn about the requested architectures that were filtered out.
    invalid_arch_list = pytorch_rocm_arch - ROCM_SUPPORTED_ARCHS
    if invalid_arch_list:
        warnings.warn(
            f"Unsupported ROCM architectures ({invalid_arch_list}) are "
            f"excluded from the {arch_source_str} output "
            f"({env_arch_list}). Supported ROCM architectures are: "
            f"{ROCM_SUPPORTED_ARCHS}.",
            stacklevel=2)
    return arch_list


185
186
187
188
189
190
191
def get_torch_arch_list() -> Set[str]:
    """Return the supported architectures listed in TORCH_CUDA_ARCH_LIST.

    Returns:
        The entries of the env variable that vLLM supports, or an empty set
        when the variable is unset, empty, or contains only separators.

    Raises:
        RuntimeError: If entries are given but none of them is supported.
    """
    # TORCH_CUDA_ARCH_LIST can have one or more architectures,
    # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
    # compiler to additionally include PTX code that can be runtime-compiled
    # and executed on the 8.6 or newer architectures. While the PTX code will
    # not give the best performance on the newer architectures, it provides
    # forward compatibility.
    env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
    if env_arch_list is None:
        return set()

    # Entries are separated by ";" or space. Splitting an empty string still
    # yields one empty token, so empty tokens are filtered out; otherwise an
    # empty variable would never reach the "nothing requested" return below.
    torch_arch_list = {
        arch
        for arch in env_arch_list.replace(" ", ";").split(";") if arch
    }
    if not torch_arch_list:
        return set()

    # Each supported architecture is accepted with or without "+PTX".
    valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
        {s + "+PTX"
         for s in NVIDIA_SUPPORTED_ARCHS})
    arch_list = torch_arch_list.intersection(valid_archs)
    # If none of the specified architectures are valid, raise an error.
    if not arch_list:
        raise RuntimeError(
            "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
            f"variable ({env_arch_list}) is supported. "
            f"Supported CUDA architectures are: {valid_archs}.")

    # Warn about the requested architectures that were filtered out.
    invalid_arch_list = torch_arch_list - valid_archs
    if invalid_arch_list:
        warnings.warn(
            f"Unsupported CUDA architectures ({invalid_arch_list}) are "
            "excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
            f"({env_arch_list}). Supported CUDA architectures are: "
            f"{valid_archs}.",
            stacklevel=2)
    return arch_list
221
222


223
224
225
226
227
228
229
if _is_hip():
    rocm_arches = get_pytorch_rocm_arch()
    # Sorted so the generated flag list does not depend on set ordering.
    NVCC_FLAGS += ["--offload-arch=" + arch for arch in sorted(rocm_arches)]
else:
    # First, check the TORCH_CUDA_ARCH_LIST environment variable.
    compute_capabilities = get_torch_arch_list()

# `compute_capabilities` is only defined on the non-HIP path; the
# `_is_cuda()` test short-circuits before it is read on HIP builds.
if _is_cuda() and not compute_capabilities:
    # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
    # GPUs on the current machine.
    device_count = torch.cuda.device_count()
    for i in range(device_count):
        major, minor = torch.cuda.get_device_capability(i)
        if major < 7:
            raise RuntimeError(
                "GPUs with compute capability below 7.0 are not supported.")
        compute_capabilities.add(f"{major}.{minor}")

# Extension modules to build; filled in by the platform-specific code below.
ext_modules = []

243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
if _is_cuda():
    nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
    if not compute_capabilities:
        # If no GPU is specified nor available, add all supported architectures
        # based on the NVCC CUDA version.
        compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy()
        if nvcc_cuda_version < Version("11.1"):
            compute_capabilities.remove("8.6")
        if nvcc_cuda_version < Version("11.8"):
            compute_capabilities.remove("8.9")
            compute_capabilities.remove("9.0")
    # Validate the NVCC CUDA version.
    if nvcc_cuda_version < Version("11.0"):
        raise RuntimeError(
            "CUDA 11.0 or higher is required to build the package.")
    if (nvcc_cuda_version < Version("11.1")
            and any(cc.startswith("8.6") for cc in compute_capabilities)):
        raise RuntimeError(
            "CUDA 11.1 or higher is required for compute capability 8.6.")
    if nvcc_cuda_version < Version("11.8"):
        if any(cc.startswith("8.9") for cc in compute_capabilities):
            # CUDA 11.8 is required to generate the code targeting compute
            # capability 8.9. However, GPUs with compute capability 8.9 can
            # also run the code generated by the previous versions of CUDA 11
            # and targeting compute capability 8.0. Therefore, if CUDA 11.8
            # is not available, we target compute capability 8.0 instead of 8.9.
            warnings.warn(
                "CUDA 11.8 or higher is required for compute capability 8.9. "
                "Targeting compute capability 8.0 instead.",
                stacklevel=2)
            compute_capabilities = {
                cc
                for cc in compute_capabilities if not cc.startswith("8.9")
            }
            compute_capabilities.add("8.0+PTX")
        if any(cc.startswith("9.0") for cc in compute_capabilities):
            raise RuntimeError(
                "CUDA 11.8 or higher is required for compute capability 9.0.")

    # Punica gets its own flag list, copied before any -gencode flags are
    # added; only capabilities with major version >= 8 are added to it below.
    NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy()

    # Add target compute capabilities to NVCC flags. Iterate in sorted order
    # so the generated command line does not depend on set ordering.
    for capability in sorted(compute_capabilities):
        # "8.6" / "8.6+PTX" -> "86": first and third characters are the
        # major and minor digits.
        num = capability[0] + capability[2]
        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
        if capability.endswith("+PTX"):
            NVCC_FLAGS += [
                "-gencode", f"arch=compute_{num},code=compute_{num}"
            ]
        if int(capability[0]) >= 8:
            NVCC_FLAGS_PUNICA += [
                "-gencode", f"arch=compute_{num},code=sm_{num}"
            ]
            if capability.endswith("+PTX"):
                NVCC_FLAGS_PUNICA += [
                    "-gencode", f"arch=compute_{num},code=compute_{num}"
                ]

    # Use NVCC threads to parallelize the build.
    if nvcc_cuda_version >= Version("11.2"):
        nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
        # os.cpu_count() may return None; fall back to a single thread then.
        num_threads = min(os.cpu_count() or 1, nvcc_threads)
        NVCC_FLAGS += ["--threads", str(num_threads)]

    if nvcc_cuda_version >= Version("11.8"):
        NVCC_FLAGS += ["-DENABLE_FP8_E5M2"]

    # changes for punica kernels
    # NVCC_FLAGS receives torch's common flags as they are *before* the
    # removals below; the removals mutate torch's shared list in place.
    NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS
    REMOVE_NVCC_FLAGS = [
        '-D__CUDA_NO_HALF_OPERATORS__',
        '-D__CUDA_NO_HALF_CONVERSIONS__',
        '-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
        '-D__CUDA_NO_HALF2_OPERATORS__',
    ]
    for flag in REMOVE_NVCC_FLAGS:
        # A flag that is already absent is not an error.
        with contextlib.suppress(ValueError):
            torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag)

    # Punica kernels are opt-in through VLLM_INSTALL_PUNICA_KERNELS and are
    # skipped when any visible GPU has a major compute capability below 8.
    install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))
    device_count = torch.cuda.device_count()
    for i in range(device_count):
        major, minor = torch.cuda.get_device_capability(i)
        if major < 8:
            install_punica = False
            break
    if install_punica:
        ext_modules.append(
            CUDAExtension(
                name="vllm._punica_C",
                sources=["csrc/punica/punica_ops.cc"] +
                glob("csrc/punica/bgmv/*.cu"),
                extra_compile_args={
                    "cxx": CXX_FLAGS,
                    "nvcc": NVCC_FLAGS_PUNICA,
                },
            ))
elif _is_neuron():
    neuronxcc_version = get_neuronxcc_version()
340

341
342
343
344
345
346
347
# Sources compiled into the main `vllm._C` extension on every non-Neuron
# build.
vllm_extension_sources = [
    "csrc/cache_kernels.cu",
    "csrc/attention/attention_kernels.cu",
    "csrc/pos_encoding_kernels.cu",
    "csrc/activation_kernels.cu",
    "csrc/layernorm_kernels.cu",
    "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
    "csrc/quantization/gptq/q_gemm.cu",
    "csrc/cuda_utils_kernels.cu",
    "csrc/moe_align_block_size_kernels.cu",
    "csrc/pybind.cpp",
]

if _is_cuda():
    # Kernels that are only built for CUDA.
    vllm_extension_sources.extend([
        "csrc/quantization/awq/gemm_kernels.cu",
        "csrc/quantization/marlin/marlin_cuda_kernel.cu",
        "csrc/custom_all_reduce.cu",
    ])

    # Add MoE kernels.
    ext_modules.append(
        CUDAExtension(
            name="vllm._moe_C",
            sources=glob("csrc/moe/*.cu") + glob("csrc/moe/*.cpp"),
            extra_compile_args={
                "cxx": CXX_FLAGS,
                "nvcc": NVCC_FLAGS,
            },
        ))

if not _is_neuron():
    vllm_extension = CUDAExtension(
        name="vllm._C",
        sources=vllm_extension_sources,
        extra_compile_args={
            "cxx": CXX_FLAGS,
            "nvcc": NVCC_FLAGS,
        },
        # Link against libcuda only on CUDA builds.
        libraries=["cuda"] if _is_cuda() else [],
    )
    ext_modules.append(vllm_extension)
382

383

384
385
386
387
def get_path(*filepath) -> str:
    """Join the given path components onto ROOT_DIR and return the result."""
    components = (ROOT_DIR, ) + filepath
    return os.path.join(*components)


388
def find_version(filepath: str) -> str:
    """Extract version information from the given filepath.

    Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py

    Args:
        filepath: Path of a Python file containing a line of the form
            ``__version__ = "x.y.z"`` (single or double quotes).

    Returns:
        The quoted version string from the first matching line.

    Raises:
        RuntimeError: If no ``__version__`` assignment is found.
    """
    # Read as UTF-8 so the result does not depend on the locale encoding.
    with open(filepath, encoding="utf-8") as fp:
        contents = fp.read()
    # re.M lets "^" anchor at the start of any line, not just the file.
    version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
                              contents, re.M)
    if version_match is None:
        raise RuntimeError("Unable to find version string.")
    return version_match.group(1)


401
402
def get_vllm_version() -> str:
    """Return the vLLM version string, with a platform suffix if applicable.

    The base version is read from ``vllm/__init__.py``. A local version
    suffix (``+cuXXX``, ``+rocmXXX`` or ``+neuronXXX``) is appended when the
    detected toolkit version string differs from MAIN_CUDA_VERSION.

    Raises:
        RuntimeError: If the HIP version cannot be determined on a ROCm
            build, or if the platform is none of CUDA, ROCm or Neuron.
    """
    version = find_version(get_path("vllm", "__init__.py"))

    if _is_cuda():
        cuda_version = str(nvcc_cuda_version)
        if cuda_version != MAIN_CUDA_VERSION:
            # e.g. "11.8" -> "118"; at most three characters are kept.
            cuda_version_str = cuda_version.replace(".", "")[:3]
            version += f"+cu{cuda_version_str}"
    elif _is_hip():
        # Get the HIP version
        hipcc_version = get_hipcc_rocm_version()
        if hipcc_version is None:
            # The lookup reports failure by returning None; fail with a
            # clear message rather than an AttributeError on `.replace`.
            raise RuntimeError(
                "Unable to determine the HIP version from `hipcc --version`.")
        # NOTE(review): the HIP version is compared with the CUDA reference
        # version, so this presumably always differs — confirm intent.
        if hipcc_version != MAIN_CUDA_VERSION:
            rocm_version_str = hipcc_version.replace(".", "")[:3]
            version += f"+rocm{rocm_version_str}"
    elif _is_neuron():
        # Get the Neuron version
        neuron_version = str(neuronxcc_version)
        if neuron_version != MAIN_CUDA_VERSION:
            neuron_version_str = neuron_version.replace(".", "")[:3]
            version += f"+neuron{neuron_version_str}"
    else:
        raise RuntimeError("Unknown runtime environment")

    return version


427
def read_readme() -> str:
    """Read the README file if present.

    Returns:
        The UTF-8 decoded contents of ``README.md`` next to this script, or
        an empty string when the file does not exist.
    """
    p = get_path("README.md")
    if not os.path.isfile(p):
        return ""
    # Use a context manager so the file handle is closed deterministically.
    with io.open(p, "r", encoding="utf-8") as f:
        return f.read()
434
435


436
437
def get_requirements() -> List[str]:
    """Get Python package dependencies from the platform's requirements file.

    Returns:
        The lines of ``requirements.txt`` (CUDA), ``requirements-rocm.txt``
        (ROCm) or ``requirements-neuron.txt`` (Neuron).

    Raises:
        ValueError: If the platform is none of CUDA, ROCm or Neuron.
    """
    if _is_cuda():
        with open(get_path("requirements.txt")) as f:
            requirements = f.read().strip().split("\n")
        if nvcc_cuda_version <= Version("11.8"):
            # replace cupy-cuda12x with cupy-cuda11x for cuda 11.x
            # Only the first matching line is replaced, and the whole line
            # (including any version specifier) becomes "cupy-cuda11x".
            for i, requirement in enumerate(requirements):
                if requirement.startswith("cupy-cuda12x"):
                    requirements[i] = "cupy-cuda11x"
                    break
    elif _is_hip():
        with open(get_path("requirements-rocm.txt")) as f:
            requirements = f.read().strip().split("\n")
    elif _is_neuron():
        with open(get_path("requirements-neuron.txt")) as f:
            requirements = f.read().strip().split("\n")
    else:
        raise ValueError(
            "Unsupported platform, please use CUDA, ROCM or Neuron.")

    return requirements


460
461
462
# Non-Python files shipped inside the `vllm` package.
package_data = {
    "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
}

# When VLLM_USE_PRECOMPILED is set to any non-empty value, build no
# extensions and package the existing shared objects instead.
if os.environ.get("VLLM_USE_PRECOMPILED"):
    ext_modules = []
    package_data["vllm"].append("*.so")

Woosuk Kwon's avatar
Woosuk Kwon committed
467
setuptools.setup(
    name="vllm",
    version=get_vllm_version(),
    author="vLLM Team",
    license="Apache 2.0",
    description=("A high-throughput and memory-efficient inference and "
                 "serving engine for LLMs"),
    long_description=read_readme(),
    long_description_content_type="text/markdown",
    url="https://github.com/vllm-project/vllm",
    project_urls={
        "Homepage": "https://github.com/vllm-project/vllm",
        "Documentation": "https://vllm.readthedocs.io/en/latest/",
    },
    classifiers=[
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "License :: OSI Approved :: Apache Software License",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
                                               "examples", "tests")),
    python_requires=">=3.8",
    install_requires=get_requirements(),
    ext_modules=ext_modules,
    # Neuron builds register no custom build_ext command.
    cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
    package_data=package_data,
)