# Optimization and Tuning

This guide covers optimization strategies and performance tuning for vLLM V1.

!!! tip
    Running out of memory? Consult [this guide](./conserving_memory.md) on how to conserve memory.

## Preemption

Due to the auto-regressive nature of the transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
In such cases, vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes
available again. When this occurs, you may see the following warning:

```text
WARNING 05-09 00:49:33 scheduler.py:1057 Sequence group 0 is preempted by PreemptionMode.RECOMPUTE mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_cumulative_preemption_cnt=1
```

While this mechanism ensures system robustness, preemption and recomputation can adversely affect end-to-end latency.

If you frequently encounter preemptions, consider the following actions:

- Increase `gpu_memory_utilization`. vLLM pre-allocates GPU cache using this percentage of memory. By increasing utilization, you can provide more KV cache space.
- Decrease `max_num_seqs` or `max_num_batched_tokens`. This reduces the number of concurrent requests in a batch, thereby requiring less KV cache space.
- Increase `tensor_parallel_size`. This shards model weights across GPUs, allowing each GPU to have more memory available for KV cache. However, increasing this value may cause excessive synchronization overhead.
- Increase `pipeline_parallel_size`. This distributes model layers across GPUs, reducing the memory needed for model weights on each GPU, indirectly leaving more memory available for KV cache. However, increasing this value may cause latency penalties.
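
To get a feel for how these knobs change the available KV cache, the following back-of-the-envelope sketch estimates how many tokens fit on one GPU. It is not how vLLM measures memory: the helper names, the 80 GiB GPU, the 15 GiB of weights, and the model dimensions are illustrative assumptions, and activation and framework overhead are lumped into a single optional term.

```python
def kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes=2):
    """Bytes of KV cache one token occupies across all layers (keys + values)."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def kv_cache_token_capacity(gpu_mem_gib, gpu_memory_utilization, weights_gib,
                            bytes_per_token, overhead_gib=0.0):
    """Rough number of tokens whose KV cache fits on one GPU.

    Assumption: KV cache budget = usable memory - weights - other overhead.
    """
    budget_gib = gpu_mem_gib * gpu_memory_utilization - weights_gib - overhead_gib
    return max(0, int(budget_gib * 1024**3) // bytes_per_token)


# Illustrative numbers only (hypothetical 80 GiB GPU, 16-bit KV cache).
bytes_per_token = kv_cache_bytes_per_token(num_layers=32, num_kv_heads=8, head_dim=128)
for util in (0.90, 0.95):
    tokens = kv_cache_token_capacity(80, util, weights_gib=15,
                                     bytes_per_token=bytes_per_token)
    print(f"gpu_memory_utilization={util}: ~{tokens:,} tokens of KV cache")
```

Under these assumptions, raising `gpu_memory_utilization` or lowering the per-GPU weight footprint (via tensor or pipeline parallelism) directly increases the token capacity, which in turn reduces how often requests must be preempted.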

You can monitor the number of preemption requests through Prometheus metrics exposed by vLLM. Additionally, you can log the cumulative number of preemption requests by setting `disable_log_stats=False`.

In vLLM V1, the default preemption mode is `RECOMPUTE` rather than `SWAP`, as recomputation has lower overhead in the V1 architecture.

[](){ #chunked-prefill }

## Chunked Prefill

Chunked prefill allows vLLM to process large prefills in smaller chunks and batch them together with decode requests. This feature helps improve both throughput and latency by better balancing compute-bound (prefill) and memory-bound (decode) operations.

In vLLM V1, **chunked prefill is always enabled by default**. This is different from vLLM V0, where it was conditionally enabled based on model characteristics.

With chunked prefill enabled, the scheduling policy prioritizes decode requests. It batches all pending decode requests before scheduling any prefill operations. When there are available tokens in the `max_num_batched_tokens` budget, it schedules pending prefills. If a pending prefill request cannot fit into `max_num_batched_tokens`, it automatically chunks it.

This policy has two benefits:

- It improves inter-token latency (ITL) and decode speed because decode requests are prioritized.
- It helps achieve better GPU utilization by placing compute-bound (prefill) and memory-bound (decode) requests in the same batch.
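
The decode-first policy described above can be illustrated with a toy scheduler. This is a simplified sketch, not vLLM's actual scheduler: the `Request` class and `schedule_step` function are hypothetical, and KV cache limits, `max_num_seqs`, and preemption are ignored.

```python
from dataclasses import dataclass


@dataclass
class Request:
    name: str
    remaining_prefill: int  # prompt tokens not yet processed; 0 means decoding


def schedule_step(requests, max_num_batched_tokens):
    """Return {request name: tokens scheduled this step}."""
    budget = max_num_batched_tokens
    plan = {}
    # 1) Decode requests go first; each consumes one token of the budget.
    for req in requests:
        if req.remaining_prefill == 0 and budget > 0:
            plan[req.name] = 1
            budget -= 1
    # 2) Leftover budget goes to prefills; a prefill that does not fit is chunked.
    for req in requests:
        if req.name not in plan and req.remaining_prefill > 0 and budget > 0:
            chunk = min(req.remaining_prefill, budget)
            plan[req.name] = chunk
            req.remaining_prefill -= chunk
            budget -= chunk
    return plan


# Two requests are already decoding; a third arrives with a 5000-token prompt.
requests = [Request("a", 0), Request("b", 0), Request("c", 5000)]
for step in range(4):
    print(step, schedule_step(requests, max_num_batched_tokens=2048))
```

In this sketch the long prompt is spread over several steps, while the two decoding requests keep producing one token per step instead of stalling behind the prefill.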

### Performance Tuning with Chunked Prefill

You can tune the performance by adjusting `max_num_batched_tokens`:

- Smaller values (e.g., 2048) achieve better inter-token latency (ITL) because there are fewer prefills slowing down decodes.
- Higher values achieve better time to first token (TTFT) as you can process more prefill tokens in a batch.
- For optimal throughput, we recommend setting `max_num_batched_tokens > 8192` especially for smaller models on large GPUs.
- If `max_num_batched_tokens` is the same as `max_model_len`, the behavior is almost equivalent to the V0 default scheduling policy (except that it still prioritizes decodes).

```python
from vllm import LLM

# Set max_num_batched_tokens to tune performance
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", max_num_batched_tokens=16384)
```

See related papers for more details (<https://arxiv.org/pdf/2401.08671> or <https://arxiv.org/pdf/2308.16369>).

## Parallelism Strategies

vLLM supports multiple parallelism strategies that can be combined to optimize performance across different hardware configurations.

### Tensor Parallelism (TP)

Tensor parallelism shards model parameters across multiple GPUs within each model layer. This is the most common strategy for large model inference within a single node.

**When to use:**

- When the model is too large to fit on a single GPU
- When you need to reduce memory pressure per GPU to allow more KV cache space for higher throughput

```python
from vllm import LLM

# Split model across 4 GPUs
llm = LLM(model="meta-llama/Llama-3.3-70B-Instruct", tensor_parallel_size=4)
```

For models that are too large to fit on a single GPU (like 70B parameter models), tensor parallelism is essential.
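
A quick arithmetic sketch shows why. It assumes 16-bit weights that are sharded evenly across tensor-parallel ranks (in practice some tensors may be replicated) and a hypothetical 80 GiB GPU; the helper name is illustrative.

```python
def weight_gib_per_gpu(num_params_billion, bytes_per_param=2, tensor_parallel_size=1):
    """Approximate weight memory per GPU, assuming perfectly even sharding."""
    total_bytes = num_params_billion * 1e9 * bytes_per_param
    return total_bytes / tensor_parallel_size / 1024**3


GPU_MEM_GIB = 80  # hypothetical accelerator size
for tp in (1, 2, 4, 8):
    gib = weight_gib_per_gpu(70, tensor_parallel_size=tp)
    verdict = "fits" if gib < GPU_MEM_GIB else "does not fit"
    print(f"TP={tp}: ~{gib:.1f} GiB of weights per GPU ({verdict} in {GPU_MEM_GIB} GiB)")
```

Whatever memory remains after the weights is what the KV cache and activations have to share, so a configuration that only just fits the weights leaves little room for serving concurrent requests.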

### Pipeline Parallelism (PP)

Pipeline parallelism distributes model layers across multiple GPUs. Each GPU processes different parts of the model in sequence.

**When to use:**

- When you've already maxed out efficient tensor parallelism but need to distribute the model further, or across nodes
- For very deep and narrow models where layer distribution is more efficient than tensor sharding

Pipeline parallelism can be combined with tensor parallelism for very large models:

```python
from vllm import LLM

# Combine pipeline and tensor parallelism
llm = LLM(
    model="meta-llama/Llama-3.3-70B-Instruct",
    tensor_parallel_size=4,
    pipeline_parallel_size=2
)
```

### Expert Parallelism (EP)

Expert parallelism is a specialized form of parallelism for Mixture of Experts (MoE) models, where different expert networks are distributed across GPUs.

**When to use:**

- Specifically for MoE models (like DeepSeekV3, Qwen3MoE, Llama-4)
- When you want to balance the expert computation load across GPUs

Expert parallelism is enabled by setting `enable_expert_parallel=True`, which will use expert parallelism instead of tensor parallelism for MoE layers.
It will use the same degree of parallelism as what you have set for tensor parallelism.

### Data Parallelism (DP)

Data parallelism replicates the entire model across multiple GPU sets and processes different batches of requests in parallel.

**When to use:**

- When you have enough GPUs to replicate the entire model
- When you need to scale throughput rather than model size
- In multi-user environments where isolation between request batches is beneficial

Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.
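
As a rough illustration of that note, the sketch below computes the sharding degree for MoE layers and, assuming expert parallelism is enabled and experts divide evenly, the number of experts each rank would host. The function name and the 256-expert model are hypothetical, and pipeline parallelism is assumed to be 1.

```python
def moe_expert_layout(num_experts, tensor_parallel_size, data_parallel_size):
    """Sharding degree of MoE layers and experts hosted per rank."""
    shard_degree = tensor_parallel_size * data_parallel_size
    return {
        "total_gpus": shard_degree,            # assumes pipeline_parallel_size=1
        "moe_shard_degree": shard_degree,      # TP size x DP size, per the note above
        "experts_per_rank": -(-num_experts // shard_degree),  # ceiling division
    }


# Hypothetical 256-expert model with tensor_parallel_size=8, data_parallel_size=2.
print(moe_expert_layout(256, tensor_parallel_size=8, data_parallel_size=2))
```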

### Batch-level DP for Multi-Modal Encoders

By default, TP is used to shard the weights of multi-modal encoders just like for language decoders,
in order to reduce the memory and compute load on each GPU.

However, since the size of multi-modal encoders is very small compared to language decoders,
there is relatively little gain from TP. On the other hand, TP incurs significant communication
overhead because of all-reduce being performed after every layer.

Given this, it may be advantageous to instead shard the batched input data using TP, essentially
performing batch-level DP. This has been shown to improve the throughput by around 10% for
`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations,
batch-level DP can provide another 40% increase to throughput compared to regular TP.

Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank,
there will be a minor increase in memory consumption, which may cause OOM if the model barely fits already.
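
Conceptually, batch-level DP splits the batch of inputs rather than the encoder weights. The sketch below illustrates the idea in plain Python; it is not vLLM's implementation, and `shard_batch`, `encode_with_batch_dp`, and the `encode` callable are hypothetical stand-ins for the per-rank encoder and the gather step.

```python
def shard_batch(items, tp_size):
    """Split items into tp_size contiguous, near-equal shards (one per TP rank)."""
    base, extra = divmod(len(items), tp_size)
    shards, start = [], 0
    for rank in range(tp_size):
        size = base + (1 if rank < extra else 0)
        shards.append(items[start:start + size])
        start += size
    return shards


def encode_with_batch_dp(items, tp_size, encode):
    """Each rank runs the full (replicated) encoder on its own shard."""
    shards = shard_batch(items, tp_size)
    # On real hardware these per-rank loops would run concurrently.
    per_rank_outputs = [[encode(item) for item in shard] for shard in shards]
    # Gather the results back in the original order.
    return [out for outputs in per_rank_outputs for out in outputs]


print(shard_batch(list(range(10)), tp_size=4))
print(encode_with_batch_dp(["img0", "img1", "img2"], tp_size=2, encode=str.upper))
```

Because every rank holds a full copy of the encoder, no all-reduce is needed between encoder layers; only the final outputs are exchanged.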

You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example:

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-72B-Instruct",
    tensor_parallel_size=4,
    # When mm_encoder_tp_mode="data",
    # the vision encoder uses TP=4 (not DP=1) to shard the input data,
    # so the TP size becomes the effective DP size.
    # Note that this is independent of the DP size for language decoder which is used in expert parallel setting.
    mm_encoder_tp_mode="data",
    # The language decoder uses TP=4 to shard the weights regardless
    # of the setting of mm_encoder_tp_mode
)
```

!!! important
    Batch-level DP is not to be confused with API request-level DP
    (which is instead controlled by `data_parallel_size`).

Batch-level DP needs to be implemented on a per-model basis,
and enabled by setting `supports_encoder_tp_data = True` in the model class.
Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to use this feature.

Known supported models:

- Llama4 (<gh-pr:18368>)
- MiniCPM-V-4 (<gh-pr:23327>)
- Qwen2.5-VL (<gh-pr:22742>)
- Step3 (<gh-pr:22697>)

## Input Processing

### Parallel Processing

You can run input processing in parallel via [API server scale-out](../serving/data_parallel_deployment.md#internal-load-balancing).
This is useful when input processing (which is run inside the API server)
becomes a bottleneck compared to model execution (which is run inside engine core)
and you have excess CPU capacity.

```console
# Run 4 API processes and 1 engine core process
vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4

# Run 4 API processes and 2 engine core processes
vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
```

!!! note
    API server scale-out is only available for online inference.

!!! warning
    By default, 8 CPU threads are used in each API server to load media items (e.g. images)
    from request data.

    If you apply API server scale-out, consider adjusting `VLLM_MEDIA_LOADING_THREAD_COUNT`
    to avoid CPU resource exhaustion.
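
One way to pick a value is to divide a CPU budget across the API servers. The heuristic below is only a sketch: `media_threads_per_server` and the notion of a `cpu_budget` are assumptions made for illustration, not a vLLM recommendation. The environment variable must be set before the server processes start.

```python
import os

DEFAULT_MEDIA_THREADS = 8  # default per API server, as stated in the warning above


def media_threads_per_server(api_server_count, cpu_budget):
    """Largest per-server thread count that keeps the total within cpu_budget."""
    return max(1, min(DEFAULT_MEDIA_THREADS, cpu_budget // api_server_count))


# Hypothetical setup: 4 API servers sharing 16 cores set aside for media loading.
threads = media_threads_per_server(api_server_count=4, cpu_budget=16)
os.environ["VLLM_MEDIA_LOADING_THREAD_COUNT"] = str(threads)
print(f"VLLM_MEDIA_LOADING_THREAD_COUNT={threads}")
```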

!!! note
    API server scale-out disables [multi-modal IPC caching](#ipc-caching)
    because it requires a one-to-one correspondence between API and engine core processes.

    This does not impact [multi-modal processor caching](#processor-caching).

## Multi-Modal Caching

Multi-modal caching avoids repeated transfer or processing of the same multi-modal data,
which commonly occurs in multi-turn conversations.

### Processor Caching

Multi-modal processor caching is automatically enabled
to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`.

### IPC Caching

Multi-modal IPC caching is automatically enabled when
there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes,
to avoid repeatedly transferring the same multi-modal inputs between them.
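
The idea can be illustrated with a toy model in which the sending side remembers only content hashes and the receiving side stores the data itself. This is a conceptual sketch, not vLLM's implementation: the class names are hypothetical, raw bytes stand in for processed tensors, and eviction and size limits are omitted.

```python
import hashlib


class SenderCache:
    """API process side (P0): remembers only the hashes of items already sent."""

    def __init__(self):
        self.seen = set()

    def prepare(self, data: bytes):
        key = hashlib.sha256(data).hexdigest()
        if key in self.seen:
            return key, None  # receiver already holds the data: send the hash only
        self.seen.add(key)
        return key, data


class ReceiverCache:
    """Engine core side (P1): stores hash -> data."""

    def __init__(self):
        self.store = {}

    def resolve(self, key, payload):
        if payload is not None:
            self.store[key] = payload
        return self.store[key]


p0, p1 = SenderCache(), ReceiverCache()
image = b"pretend these bytes are a processed image"
for turn in range(2):
    key, payload = p0.prepare(image)
    data = p1.resolve(key, payload)
    sent = "hash only" if payload is None else "hash + data"
    print(f"turn {turn}: transferred {sent}, receiver got {len(data)} bytes")
```

This only works if each sender talks to exactly one receiver, which is why the one-to-one correspondence between the two processes matters.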

### Configuration

You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB).

If you do not benefit much from the cache, you can disable both IPC
and processor caching completely via `mm_processor_cache_gb=0`.

Examples:

```python
from vllm import LLM

# Use a larger cache
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          mm_processor_cache_gb=8)

# Disable the cache
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          mm_processor_cache_gb=0)
```

### Cache Placement

Based on the configuration, the contents of the multi-modal caches on `P0` and `P1` are as follows:

| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory |
|-------------------|-------------|------------|------------|-------------|
| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` |
| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` |
| ❌ | ❌ | N/A | N/A | `0` |

K: Stores the hashes of multi-modal items  
V: Stores the processed tensor data of multi-modal items
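
The "Max. Memory" column can be read as a small formula. The helper below simply encodes the three rows of the table; the function name and its keyword arguments are hypothetical and exist only for illustration.

```python
def max_mm_cache_gib(mm_processor_cache_gb, *, ipc_caching,
                     api_server_count=1, data_parallel_size=1):
    """Upper bound on multi-modal cache memory, following the table above."""
    if mm_processor_cache_gb == 0:
        return 0  # both caches disabled
    if ipc_caching:
        # Values live on the engine core side (P1), one cache per DP rank.
        return mm_processor_cache_gb * data_parallel_size
    # Values live on the API server side (P0), one cache per API server.
    return mm_processor_cache_gb * api_server_count


print(max_mm_cache_gib(4, ipc_caching=True, data_parallel_size=2))
print(max_mm_cache_gib(4, ipc_caching=False, api_server_count=4))
print(max_mm_cache_gib(0, ipc_caching=False))
```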