The community frequently requests the ability to extend vLLM with custom features. To facilitate this, vLLM includes a plugin system that allows users to add custom features without modifying the vLLM codebase. This document explains how plugins work in vLLM and how to create a plugin for vLLM.
## How Plugins Work in vLLM
Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see [](#arch-overview)), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the [load_general_plugins](https://github.com/vllm-project/vllm/blob/c76ac49d266e27aa3fea84ef2df1f813d24c91c7/vllm/plugins/__init__.py#L16) function in the `vllm.plugins` module. This function is called for every process created by vLLM before it starts any work.
## How vLLM Discovers Plugins
vLLM's plugin system uses the standard Python `entry_points` mechanism. This mechanism allows developers to register functions in their Python packages for use by other packages. An example of a plugin:
For more information on adding entry points to your package, please check the [official documentation](https://setuptools.pypa.io/en/latest/userguide/entry_point.html).
Every plugin has three parts:
1.**Plugin group**: The name of the entry point group. vLLM uses the entry point group `vllm.general_plugins` to register general plugins. This is the key of `entry_points` in the `setup.py` file. Always use `vllm.general_plugins` for vLLM's general plugins.
2.**Plugin name**: The name of the plugin. This is the value in the dictionary of the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name.
3.**Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module.
## What Can Plugins Do?
Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
## Guidelines for Writing Plugins
-**Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.
## Compatibility Guarantee
vLLM guarantees the interface of documented plugins, such as `ModelRegistry.register_model`, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, `"vllm_add_dummy_model.my_llava:MyLlava"` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development.
The community frequently requests the ability to extend vLLM with custom features. To facilitate this, vLLM includes a plugin system that allows users to add custom features without modifying the vLLM codebase. This document explains how plugins work in vLLM and how to create a plugin for vLLM.
How Plugins Work in vLLM
------------------------
Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see :ref:`arch_overview`), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the `load_general_plugins <https://github.com/vllm-project/vllm/blob/c76ac49d266e27aa3fea84ef2df1f813d24c91c7/vllm/plugins/__init__.py#L16>`__ function in the ``vllm.plugins`` module. This function is called for every process created by vLLM before it starts any work.
How vLLM Discovers Plugins
--------------------------
vLLM's plugin system uses the standard Python ``entry_points`` mechanism. This mechanism allows developers to register functions in their Python packages for use by other packages. An example of a plugin:
if "MyLlava" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model("MyLlava",
"vllm_add_dummy_model.my_llava:MyLlava")
For more information on adding entry points to your package, please check the `official documentation <https://setuptools.pypa.io/en/latest/userguide/entry_point.html>`__.
Every plugin has three parts:
1. **Plugin group**: The name of the entry point group. vLLM uses the entry point group ``vllm.general_plugins`` to register general plugins. This is the key of ``entry_points`` in the ``setup.py`` file. Always use ``vllm.general_plugins`` for vLLM's general plugins.
2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the ``entry_points`` dictionary. In the example above, the plugin name is ``register_dummy_model``. Plugins can be filtered by their names using the ``VLLM_PLUGINS`` environment variable. To load only a specific plugin, set ``VLLM_PLUGINS`` to the plugin name.
3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is ``vllm_add_dummy_model:register``, which refers to a function named ``register`` in the ``vllm_add_dummy_model`` module.
What Can Plugins Do?
--------------------
Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling ``ModelRegistry.register_model`` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
Guidelines for Writing Plugins
------------------------------
- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.
Compatibility Guarantee
-----------------------
vLLM guarantees the interface of documented plugins, such as ``ModelRegistry.register_model``, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, ``"vllm_add_dummy_model.my_llava:MyLlava"`` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development.
1.[Build from source with docker](#build-from-source-docker-rocm)
2.[Build from source](#build-from-source-rocm)
(build-from-source-docker-rocm)=
## Option 1: Build from source with docker (recommended)
You can build and install vLLM from source.
First, build a docker image from <gh-file:Dockerfile.rocm> and launch a docker container from the image.
It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon:
```console
{
"features": {
"buildkit": true
}
}
```
<gh-file:Dockerfile.rocm> uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches.
It provides flexibility to customize the build of docker image using the following arguments:
-`BASE_IMAGE`: specifies the base image used when running `docker build`, specifically the PyTorch on ROCm base image.
-`BUILD_FA`: specifies whether to build CK flash-attention. The default is 1. For [Radeon RX 7900 series (gfx1100)](https://rocm.docs.amd.com/projects/radeon/en/latest/index.html), this should be set to 0 before flash-attention supports this target.
-`FX_GFX_ARCHS`: specifies the GFX architecture that is used to build CK flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
-`FA_BRANCH`: specifies the branch used to build the CK flash-attention in [ROCm's flash-attention repo](https://github.com/ROCmSoftwarePlatform/flash-attention). The default is `ae7928c`
-`BUILD_TRITON`: specifies whether to build triton flash-attention. The default value is 1.
Their values can be passed in when running `docker build` with `--build-arg` options.
To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default:
For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`.
Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/)
1. Install [Triton flash attention for ROCm](https://github.com/ROCm/triton)
Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from [ROCm/triton](https://github.com/ROCm/triton/blob/triton-mlir/README.md)
- If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent.
```
2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention/tree/ck_tile)
Install ROCm's flash attention (v2.5.9.post1) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support)
Alternatively, wheels intended for vLLM use can be accessed under the releases.
For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`.
This may take 5-10 minutes. Currently, {code}`pip install .` does not work for ROCm installation.
```{tip}
- Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
- Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
- To use CK flash-attention or PyTorch naive attention, please use this flag `export VLLM_USE_TRITON_FLASH_ATTN=0` to turn off triton flash attention.
- The ROCm version of PyTorch, ideally, should match the ROCm driver version.
```
```{tip}
- For MI300x (gfx942) users, to achieve optimal performance, please refer to [MI300x tuning guide](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/index.html) for performance optimization and tuning tips on system and workflow level.
For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization).
vLLM has been adapted to work on ARM64 CPUs with NEON support, leveraging the CPU backend initially developed for the x86 platform. This guide provides installation instructions specific to ARM. For additional details on supported features, refer to the x86 platform documentation covering:
* CPU backend inference capabilities
* Relevant runtime environment variables
* Performance optimization tips
- CPU backend inference capabilities
- Relevant runtime environment variables
- Performance optimization tips
ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes.
Contents:
1. :ref:`Requirements <arm_backend_requirements>`
2. :ref:`Quick Start with Dockerfile <arm_backend_quick_start_dockerfile>`
3. :ref:`Building from Source <build_arm_backend_from_source>`
1.[Requirements](#arm-backend-requirements)
2.[Quick Start with Dockerfile](#arm-backend-quick-start-dockerfile)
3.[Building from Source](#build-arm-backend-from-source)
.. _arm_backend_requirements:
(arm-backend-requirements)=
Requirements
------------
## Requirements
* **Operating System**: Linux or macOS
* **Compiler**: gcc/g++ >= 12.3.0 (optional, but recommended)
* **Instruction Set Architecture (ISA)**: NEON support is required
-**Operating System**: Linux or macOS
-**Compiler**: gcc/g++ >= 12.3.0 (optional, but recommended)
-**Instruction Set Architecture (ISA)**: NEON support is required
To build vLLM from source on Ubuntu 22.04 or other Linux distributions, follow a similar process as with x86. Testing has been conducted on AWS Graviton3 instances for compatibility.
vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features:
- Tensor Parallel
- Model Quantization (`INT8 W8A8, AWQ`)
- Chunked-prefill
- Prefix-caching
- FP8-E5M2 KV-Caching (TODO)
Table of contents:
1.[Requirements](#cpu-backend-requirements)
2.[Quick start using Dockerfile](#cpu-backend-quick-start-dockerfile)
3.[Build from source](#build-cpu-backend-from-source)
- First, install recommended compiler. We recommend to use `gcc/g++ >= 12.3.0` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run:
- AVX512_BF16 is an extension ISA provides native BF16 data type conversion and vector product instructions, will brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16.
- If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building.
```
(env-intro)=
## Related runtime environment variables
-`VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
-`VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores.
(ipex-guidance)=
## Intel Extension for PyTorch
-[Intel Extension for PyTorch (IPEX)](https://github.com/intel/intel-extension-for-pytorch) extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware.
(cpu-backend-performance-tips)=
## Performance tips
- We highly recommend to use TCMalloc for high performance memory allocation and better cache locality. For example, on Ubuntu 22.4, you can run:
$find / -name*libtcmalloc*# find the dynamic link library path
$export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD# prepend the library to LD_PRELOAD
$python examples/offline_inference.py # run vLLM
```
- When using the online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserving CPU 30 and 31 for the framework and using CPU 0-29 for OpenMP:
```console
$export VLLM_CPU_KVCACHE_SPACE=40
$export VLLM_CPU_OMP_THREADS_BIND=0-29
$vllm serve facebook/opt-125m
```
- If using vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread on each physical CPU core using `VLLM_CPU_OMP_THREADS_BIND`. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
```console
$lscpu -e# check the mapping between logical CPU cores and physical CPU cores
#The "CPU" column means the logical CPU core IDs, and the "CORE" column means the physical core IDs. On this platform, two logical cores are sharing one physical core.
CPU NODE SOCKET CORE L1d:L1i:L2:L3 ONLINE MAXMHZ MINMHZ MHZ
0 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
1 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
2 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
3 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
4 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
5 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
6 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
7 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
8 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
9 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
10 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
11 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
12 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
13 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
14 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
15 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
#On this platform, it is recommend to only bind openMP threads on logical CPU cores 0-7 or 8-15
$export VLLM_CPU_OMP_THREADS_BIND=0-7
$python examples/offline_inference.py
```
- If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores using `VLLM_CPU_OMP_THREADS_BIND` to avoid cross NUMA node memory access.
## CPU Backend Considerations
- The CPU backend significantly differs from the GPU backend since the vLLM architecture was originally optimized for GPU use. A number of optimizations are needed to enhance its performance.
- Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance.
- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.md#non-uniform-memory-access-numa). For NUMA architecture, two optimizations are to recommended: Tensor Parallel or Data Parallel.
- Using Tensor Parallel for a latency constraints deployment: following GPU backend design, a Megatron-LM's parallel algorithm will be used to shard the model, based on the number of NUMA nodes (e.g. TP = 2 for a two NUMA node system). With [TP feature on CPU](gh-pr:6125) merged, Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving:
- Using Data Parallel for maximum throughput: to launch an LLM serving endpoint on each NUMA node along with one additional load balancer to dispatch the requests to those endpoints. Common solutions like [Nginx](../serving/deploying_with_nginx.md) or HAProxy are recommended. Anyscale Ray project provides the feature on LLM [serving](https://docs.ray.io/en/latest/serve/index.html). Here is the example to setup a scalable LLM serving with [Ray Serve](https://github.com/intel/llm-on-ray/blob/main/docs/setup.md).
vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features:
- Tensor Parallel
- Model Quantization (``INT8 W8A8, AWQ``)
- Chunked-prefill
- Prefix-caching
- FP8-E5M2 KV-Caching (TODO)
Table of contents:
#. :ref:`Requirements <cpu_backend_requirements>`
#. :ref:`Quick start using Dockerfile <cpu_backend_quick_start_dockerfile>`
#. :ref:`Build from source <build_cpu_backend_from_source>`
- First, install recommended compiler. We recommend to use ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run:
- AVX512_BF16 is an extension ISA provides native BF16 data type conversion and vector product instructions, will brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16.
- If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building.
.. _env_intro:
Related runtime environment variables
-------------------------------------
- ``VLLM_CPU_KVCACHE_SPACE``: specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
- ``VLLM_CPU_OMP_THREADS_BIND``: specify the CPU cores dedicated to the OpenMP threads. For example, ``VLLM_CPU_OMP_THREADS_BIND=0-31`` means there will be 32 OpenMP threads bound on 0-31 CPU cores. ``VLLM_CPU_OMP_THREADS_BIND=0-31|32-63`` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores.
.. _ipex_guidance:
Intel Extension for PyTorch
---------------------------
- `Intel Extension for PyTorch (IPEX) <https://github.com/intel/intel-extension-for-pytorch>`_ extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware.
.. _cpu_backend_performance_tips:
Performance tips
-----------------
- We highly recommend to use TCMalloc for high performance memory allocation and better cache locality. For example, on Ubuntu 22.4, you can run:
$ find / -name *libtcmalloc* # find the dynamic link library path
$ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD
$ python examples/offline_inference.py # run vLLM
- When using the online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserving CPU 30 and 31 for the framework and using CPU 0-29 for OpenMP:
.. code-block:: console
$ export VLLM_CPU_KVCACHE_SPACE=40
$ export VLLM_CPU_OMP_THREADS_BIND=0-29
$ vllm serve facebook/opt-125m
- If using vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread on each physical CPU core using ``VLLM_CPU_OMP_THREADS_BIND``. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
.. code-block:: console
$ lscpu -e # check the mapping between logical CPU cores and physical CPU cores
# The "CPU" column means the logical CPU core IDs, and the "CORE" column means the physical core IDs. On this platform, two logical cores are sharing one physical core.
CPU NODE SOCKET CORE L1d:L1i:L2:L3 ONLINE MAXMHZ MINMHZ MHZ
0 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
1 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
2 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
3 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
4 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
5 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
6 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
7 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
8 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
9 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
10 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
11 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
12 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
13 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
14 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
15 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
# On this platform, it is recommend to only bind openMP threads on logical CPU cores 0-7 or 8-15
$ export VLLM_CPU_OMP_THREADS_BIND=0-7
$ python examples/offline_inference.py
- If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores using ``VLLM_CPU_OMP_THREADS_BIND`` to avoid cross NUMA node memory access.
CPU Backend Considerations
--------------------------
- The CPU backend significantly differs from the GPU backend since the vLLM architecture was originally optimized for GPU use. A number of optimizations are needed to enhance its performance.
- Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance.
- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the `topology <https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.md#non-uniform-memory-access-numa>`_. For NUMA architecture, two optimizations are to recommended: Tensor Parallel or Data Parallel.
* Using Tensor Parallel for a latency constraints deployment: following GPU backend design, a Megatron-LM's parallel algorithm will be used to shard the model, based on the number of NUMA nodes (e.g. TP = 2 for a two NUMA node system). With `TP feature on CPU <https://github.com/vllm-project/vllm/pull/6125>`_ merged, Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving:
* Using Data Parallel for maximum throughput: to launch an LLM serving endpoint on each NUMA node along with one additional load balancer to dispatch the requests to those endpoints. Common solutions like `Nginx <../serving/deploying_with_nginx.html>`_ or HAProxy are recommended. Anyscale Ray project provides the feature on LLM `serving <https://docs.ray.io/en/latest/serve/index.html>`_. Here is the example to setup a scalable LLM serving with `Ray Serve <https://github.com/intel/llm-on-ray/blob/main/docs/setup.md>`_.
This document outlines some debugging strategies you can consider. If you think you've discovered a bug, please [search existing issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue) first to see if it has already been reported. If not, please [file a new issue](https://github.com/vllm-project/vllm/issues/new/choose), providing as much relevant information as possible.
```{note}
Once you've debugged a problem, remember to turn off any debugging environment variables defined, or simply start a new shell to avoid being affected by lingering debugging settings. Otherwise, the system might be slow with debugging functionalities left activated.
```
## Hangs downloading a model
If the model isn't already downloaded to disk, vLLM will download it from the internet which can take time and depend on your internet connection.
It's recommended to download the model first using the [huggingface-cli](https://huggingface.co/docs/huggingface_hub/en/guides/cli) and passing the local path to the model to vLLM. This way, you can isolate the issue.
## Hangs loading a model from disk
If the model is large, it can take a long time to load it from disk. Pay attention to where you store the model. Some clusters have shared filesystems across nodes, e.g. a distributed filesystem or a network filesystem, which can be slow.
It'd be better to store the model in a local disk. Additionally, have a look at the CPU memory usage, when the model is too large it might take a lot of CPU memory, slowing down the operating system because it needs to frequently swap between disk and memory.
```{note}
To isolate the model downloading and loading issue, you can use the `--load-format dummy` argument to skip loading the model weights. This way, you can check if the model downloading and loading is the bottleneck.
```
## Model is too large
If the model is too large to fit in a single GPU, you might want to [consider tensor parallelism](#distributed-serving) to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
## Enable more logging
If other strategies don't solve the problem, it's likely that the vLLM instance is stuck somewhere. You can use the following environment variables to help debug the issue:
-`export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging.
-`export CUDA_LAUNCH_BLOCKING=1` to identify which CUDA kernel is causing the problem.
-`export NCCL_DEBUG=TRACE` to turn on more logging for NCCL.
-`export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs.
## Incorrect network setup
The vLLM instance cannot get the correct IP address if you have a complicated network config. You can find a log such as `DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl` and the IP address should be the correct one.
If it's not, override the IP address using the environment variable `export VLLM_HOST_IP=<your_ip_address>`.
You might also need to set `export NCCL_SOCKET_IFNAME=<your_network_interface>` and `export GLOO_SOCKET_IFNAME=<your_network_interface>` to specify the network interface for the IP address.
## Error near `self.graph.replay()`
If vLLM crashes and the error trace captures it somewhere around `self.graph.replay()` in `vllm/worker/model_runner.py`, it is a CUDA error inside CUDAGraph.
To identify the particular CUDA operation that causes the error, you can add `--enforce-eager` to the command line, or `enforce_eager=True` to the {class}`~vllm.LLM` class to disable the CUDAGraph optimization and isolate the exact CUDA operation that causes the error.
## Incorrect hardware/driver
If GPU/CPU communication cannot be established, you can use the following Python script and follow the instructions below to confirm whether the GPU/CPU communication is working correctly.
If you are testing with multi-nodes, adjust `--nproc-per-node` and `--nnodes` according to your setup and set `MASTER_ADDR` to the correct IP address of the master node, reachable from all nodes. Then, run:
If the script runs successfully, you should see the message `sanity check is successful!`.
If the test script hangs or crashes, usually it means the hardware/drivers are broken in some sense. You should try to contact your system administrator or hardware vendor for further assistance. As a common workaround, you can try to tune some NCCL environment variables, such as `export NCCL_P2P_DISABLE=1` to see if it helps. Please check [their documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html) for more information. Please only use these environment variables as a temporary workaround, as they might affect the performance of the system. The best solution is still to fix the hardware/drivers so that the test script can run successfully.
```{note}
A multi-node environment is more complicated than a single-node one. If you see errors such as `torch.distributed.DistNetworkError`, it is likely that the network/DNS setup is incorrect. In that case, you can manually assign node rank and specify the IP via command line arguments:
- In the first node, run `NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 0 --master_addr $MASTER_ADDR test.py`.
- In the second node, run `NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 1 --master_addr $MASTER_ADDR test.py`.
Adjust `--nproc-per-node`, `--nnodes`, and `--node-rank` according to your setup, being sure to execute different commands (with different `--node-rank`) on different nodes.
```
(debugging-python-multiprocessing)=
## Python multiprocessing
### `RuntimeError` Exception
If you have seen a warning in your logs like this:
```console
WARNING 12-11 14:50:37 multiproc_worker_utils.py:281] CUDA was previously
initialized. We must use the `spawn` multiprocessing start method. Setting
An attempt has been made to start a new process before the
current process has finished its bootstrapping phase.
This probably means that you are not using fork to start your
child processes and you have forgotten to use the proper idiom
in the main module:
if __name__ == '__main__':
freeze_support()
...
The "freeze_support()" line can be omitted if the program
is not going to be frozen to produce an executable.
To fix this issue, refer to the "Safe importing of main module"
section in https://docs.python.org/3/library/multiprocessing.html
```
then you must update your Python code to guard usage of `vllm` behind a `if
__name__ == '__main__':` block. For example, instead of this:
```python
importvllm
llm=vllm.LLM(...)
```
try this instead:
```python
if__name__=='__main__':
importvllm
llm=vllm.LLM(...)
```
## Known Issues
- In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000) , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](gh-pr:6759).
- To circumvent a NCCL [bug](https://github.com/NVIDIA/nccl/issues/1234) , all vLLM processes will set an environment variable ``NCCL_CUMEM_ENABLE=0`` to disable NCCL's ``cuMem`` allocator. It does not affect performance but only gives memory benefits. When external processes want to set up a NCCL connection with vLLM's processes, they should also set this environment variable, otherwise, inconsistent environment setup will cause NCCL to hang or crash, as observed in the [RLHF integration](https://github.com/OpenRLHF/OpenRLHF/pull/604) and the [discussion](gh-issue:5723#issuecomment-2554389656) .
This document outlines some debugging strategies you can consider. If you think you've discovered a bug, please `search existing issues <https://github.com/vllm-project/vllm/issues?q=is%3Aissue>`_ first to see if it has already been reported. If not, please `file a new issue <https://github.com/vllm-project/vllm/issues/new/choose>`_, providing as much relevant information as possible.
.. note::
Once you've debugged a problem, remember to turn off any debugging environment variables defined, or simply start a new shell to avoid being affected by lingering debugging settings. Otherwise, the system might be slow with debugging functionalities left activated.
Hangs downloading a model
----------------------------------------
If the model isn't already downloaded to disk, vLLM will download it from the internet which can take time and depend on your internet connection.
It's recommended to download the model first using the `huggingface-cli <https://huggingface.co/docs/huggingface_hub/en/guides/cli>`_ and passing the local path to the model to vLLM. This way, you can isolate the issue.
Hangs loading a model from disk
----------------------------------------
If the model is large, it can take a long time to load it from disk. Pay attention to where you store the model. Some clusters have shared filesystems across nodes, e.g. a distributed filesystem or a network filesystem, which can be slow.
It'd be better to store the model in a local disk. Additionally, have a look at the CPU memory usage, when the model is too large it might take a lot of CPU memory, slowing down the operating system because it needs to frequently swap between disk and memory.
.. note::
To isolate the model downloading and loading issue, you can use the ``--load-format dummy`` argument to skip loading the model weights. This way, you can check if the model downloading and loading is the bottleneck.
Model is too large
----------------------------------------
If the model is too large to fit in a single GPU, you might want to `consider tensor parallelism <https://docs.vllm.ai/en/latest/serving/distributed_serving.html#distributed-inference-and-serving>`_ to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using `this example <https://docs.vllm.ai/en/latest/getting_started/examples/save_sharded_state.html>`_ . The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
Enable more logging
----------------------------------------
If other strategies don't solve the problem, it's likely that the vLLM instance is stuck somewhere. You can use the following environment variables to help debug the issue:
- ``export VLLM_LOGGING_LEVEL=DEBUG`` to turn on more logging.
- ``export CUDA_LAUNCH_BLOCKING=1`` to identify which CUDA kernel is causing the problem.
- ``export NCCL_DEBUG=TRACE`` to turn on more logging for NCCL.
- ``export VLLM_TRACE_FUNCTION=1`` to record all function calls for inspection in the log files to tell which function crashes or hangs.
Incorrect network setup
----------------------------------------
The vLLM instance cannot get the correct IP address if you have a complicated network config. You can find a log such as ``DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl`` and the IP address should be the correct one.
If it's not, override the IP address using the environment variable ``export VLLM_HOST_IP=<your_ip_address>``.
You might also need to set ``export NCCL_SOCKET_IFNAME=<your_network_interface>`` and ``export GLOO_SOCKET_IFNAME=<your_network_interface>`` to specify the network interface for the IP address.
Error near ``self.graph.replay()``
----------------------------------------
If vLLM crashes and the error trace captures it somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a CUDA error inside CUDAGraph.
To identify the particular CUDA operation that causes the error, you can add ``--enforce-eager`` to the command line, or ``enforce_eager=True`` to the :class:`~vllm.LLM` class to disable the CUDAGraph optimization and isolate the exact CUDA operation that causes the error.
Incorrect hardware/driver
----------------------------------------
If GPU/CPU communication cannot be established, you can use the following Python script and follow the instructions below to confirm whether the GPU/CPU communication is working correctly.
If you are testing with multi-nodes, adjust ``--nproc-per-node`` and ``--nnodes`` according to your setup and set ``MASTER_ADDR`` to the correct IP address of the master node, reachable from all nodes. Then, run:
If the script runs successfully, you should see the message ``sanity check is successful!``.
If the test script hangs or crashes, usually it means the hardware/drivers are broken in some sense. You should try to contact your system administrator or hardware vendor for further assistance. As a common workaround, you can try to tune some NCCL environment variables, such as ``export NCCL_P2P_DISABLE=1`` to see if it helps. Please check `their documentation <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`__ for more information. Please only use these environment variables as a temporary workaround, as they might affect the performance of the system. The best solution is still to fix the hardware/drivers so that the test script can run successfully.
.. note::
A multi-node environment is more complicated than a single-node one. If you see errors such as ``torch.distributed.DistNetworkError``, it is likely that the network/DNS setup is incorrect. In that case, you can manually assign node rank and specify the IP via command line arguments:
- In the first node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 0 --master_addr $MASTER_ADDR test.py``.
- In the second node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 1 --master_addr $MASTER_ADDR test.py``.
Adjust ``--nproc-per-node``, ``--nnodes``, and ``--node-rank`` according to your setup, being sure to execute different commands (with different ``--node-rank``) on different nodes.
Python multiprocessing
----------------------
`RuntimeError` Exception
^^^^^^^^^^^^^^^^^^^^^^^^
If you have seen a warning in your logs like this:
.. code-block:: console
WARNING 12-11 14:50:37 multiproc_worker_utils.py:281] CUDA was previously
initialized. We must use the `spawn` multiprocessing start method. Setting
An attempt has been made to start a new process before the
current process has finished its bootstrapping phase.
This probably means that you are not using fork to start your
child processes and you have forgotten to use the proper idiom
in the main module:
if __name__ == '__main__':
freeze_support()
...
The "freeze_support()" line can be omitted if the program
is not going to be frozen to produce an executable.
To fix this issue, refer to the "Safe importing of main module"
section in https://docs.python.org/3/library/multiprocessing.html
then you must update your Python code to guard usage of ``vllm`` behind a ``if
__name__ == '__main__':`` block. For example, instead of this:
.. code-block:: python
import vllm
llm = vllm.LLM(...)
try this instead:
.. code-block:: python
if __name__ == '__main__':
import vllm
llm = vllm.LLM(...)
Known Issues
----------------------------------------
- In ``v0.5.2``, ``v0.5.3``, and ``v0.5.3.post1``, there is a bug caused by `zmq <https://github.com/zeromq/pyzmq/issues/2000>`_ , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of ``vllm`` to include the `fix <https://github.com/vllm-project/vllm/pull/6759>`_.