(offline-inference)=

# Offline Inference

You can run vLLM in your own code on a list of prompts.

The offline API is based on the {class}`~vllm.LLM` class.
To initialize the vLLM engine, create a new instance of `LLM` and specify the model to run.

For example, the following code downloads the [`facebook/opt-125m`](https://huggingface.co/facebook/opt-125m) model from HuggingFace
and runs it in vLLM using the default configuration.

```python
from vllm import LLM

llm = LLM(model="facebook/opt-125m")
```

After initializing the `LLM` instance, you can perform model inference using various APIs.
The available APIs depend on the type of model that is being run:

- [Generative models](#generative-models) output log-probabilities, from which the final output text is sampled.
- [Pooling models](#pooling-models) output their hidden states directly.

Please refer to the above pages for more details about each API.
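
As a minimal sketch of the generative path, the helper below wraps {class}`~vllm.LLM` together with `SamplingParams` and `LLM.generate`. The function name `run_generation` and the sampling values are illustrative assumptions, not recommended settings.

```python
def run_generation(prompts, model="facebook/opt-125m"):
    """Generate one completion per prompt and return (prompt, text) pairs."""
    # Imported lazily so that defining this helper does not load vLLM.
    from vllm import LLM, SamplingParams

    llm = LLM(model=model)
    # Illustrative sampling settings (assumed values, tune for your use case).
    params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)
    outputs = llm.generate(prompts, params)
    # Each result keeps the original prompt and a list of completions.
    return [(out.prompt, out.outputs[0].text) for out in outputs]
```

Calling `run_generation(["Hello, my name is"])` downloads the model on first use, so expect the first call to take noticeably longer than later ones.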

:::{seealso}
[API Reference](/api/offline_inference/index)
:::

## Configuration Options

This section lists the most common options for running the vLLM engine.
For a full list, refer to the [Engine Arguments](#engine-args) page.

(model-resolution)=

### Model resolution

vLLM loads HuggingFace-compatible models by inspecting the `architectures` field in `config.json` of the model repository
and finding the corresponding implementation that is registered to vLLM.
However, model resolution may fail for the following reasons:

- The `config.json` of the model repository lacks the `architectures` field.
- Unofficial repositories refer to a model using alternative names which are not recorded in vLLM.
- The same architecture name is used for multiple models, creating ambiguity as to which model should be loaded.

To fix this, explicitly specify the model architecture by passing `config.json` overrides to the `hf_overrides` option.
For example:

```python
from vllm import LLM

model = LLM(
    model="cerebras/Cerebras-GPT-1.3B",
    hf_overrides={"architectures": ["GPT2LMHeadModel"]},  # GPT-2
)
```

Our [list of supported models](#supported-models) shows the model architectures that are recognized by vLLM.
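
If you want to apply the override only when it is needed, you can inspect a local copy of `config.json` first. The following is a sketch under that assumption. `architecture_overrides` is a hypothetical helper, not part of vLLM, and it only covers the first failure case listed above (a missing `architectures` field).

```python
import json


def architecture_overrides(config_text, fallback):
    """Return `hf_overrides` for a config.json that lacks `architectures`."""
    config = json.loads(config_text)
    # A missing or empty `architectures` list means vLLM cannot resolve the model.
    if config.get("architectures"):
        return {}
    return {"architectures": [fallback]}


# Example: a config without the field needs an explicit architecture.
overrides = architecture_overrides('{"model_type": "gpt2"}', "GPT2LMHeadModel")
```

The resulting dictionary can then be passed as `hf_overrides=overrides` when creating the `LLM` instance.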

(reducing-memory-usage)=

### Reducing memory usage

Large models might cause your machine to run out of memory (OOM). Here are some options that help alleviate this problem.

#### Tensor Parallelism (TP)

Tensor parallelism (`tensor_parallel_size` option) can be used to split the model across multiple GPUs.

The following code splits the model across 2 GPUs.

```python
from vllm import LLM

llm = LLM(model="ibm-granite/granite-3.1-8b-instruct",
          tensor_parallel_size=2)
```

:::{important}
To ensure that vLLM initializes CUDA correctly, you should avoid calling related functions (e.g. {func}`torch.cuda.set_device`)
before initializing vLLM. Otherwise, you may run into an error like `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

To control which devices are used, please instead set the `CUDA_VISIBLE_DEVICES` environment variable.
:::
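
One way to follow this advice is to set `CUDA_VISIBLE_DEVICES` from Python before vLLM is imported, and to derive `tensor_parallel_size` from it. This is a sketch. `visible_gpu_ids` and `build_llm` are hypothetical helpers, and the device IDs `"0,1"` are an assumption about your machine.

```python
import os


def visible_gpu_ids(env=None):
    """Parse CUDA_VISIBLE_DEVICES into a list of device IDs (empty if unset)."""
    env = os.environ if env is None else env
    raw = env.get("CUDA_VISIBLE_DEVICES", "")
    return [part.strip() for part in raw.split(",") if part.strip()]


def build_llm(model="ibm-granite/granite-3.1-8b-instruct", gpus="0,1"):
    """Restrict the visible GPUs, then shard the model across all of them."""
    # Set the variable before vLLM (and therefore CUDA) is initialized.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpus
    from vllm import LLM  # imported lazily, after the environment is set

    return LLM(model=model, tensor_parallel_size=len(visible_gpu_ids()))
```

This only works if nothing else in the process has initialized CUDA earlier, which is the same constraint described in the note above.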

:::{note}
With tensor parallelism enabled, each process reads the whole model and splits it into chunks, so the time spent reading from disk grows in proportion to the tensor parallel size.

You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
:::

#### Quantization

Quantized models take less memory at the cost of lower precision.

Statically quantized models can be downloaded from HF Hub (some popular ones are available at [Neural Magic](https://huggingface.co/neuralmagic))
and used directly without extra configuration.

Dynamic quantization is also supported via the `quantization` option -- see [here](#quantization-index) for more details.
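
To get a feel for the savings, you can estimate the memory taken by the weights alone at different bit widths. This is a back-of-the-envelope sketch. `weight_memory_gib` is a hypothetical helper, and it ignores quantization metadata (such as scales), activations, and the KV cache.

```python
def weight_memory_gib(num_params, bits_per_weight):
    """Approximate weight memory in GiB for a given parameter count and bit width."""
    bytes_total = num_params * bits_per_weight / 8
    return bytes_total / 2**30


# Compare 16-bit weights against 8-bit and 4-bit quantization for an 8B model.
for bits in (16, 8, 4):
    print(f"{bits}-bit: {weight_memory_gib(8e9, bits):.1f} GiB")
```

Halving the bit width halves this estimate, which is why quantization is often the first option to try when a model does not fit.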

#### Context length and batch size

You can further reduce memory usage by limiting the context length of the model (`max_model_len` option)
and the maximum batch size (`max_num_seqs` option).

```python
from vllm import LLM

llm = LLM(model="adept/fuyu-8b",
          max_model_len=2048,
          max_num_seqs=2)
```
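
To see why both options matter, consider the KV cache that a full batch of maximum-length sequences could occupy. The estimate below is a rough sketch, not a description of how vLLM actually sizes its cache. `kv_cache_bytes` is a hypothetical helper, and the layer count, KV head count, and head size are assumed values for an illustrative model.

```python
def kv_cache_bytes(max_model_len, max_num_seqs, num_layers=32,
                   num_kv_heads=8, head_dim=128, dtype_bytes=2):
    """Upper bound on KV cache bytes for a full batch of full-length sequences."""
    # Each token stores one key and one value vector per layer and KV head.
    per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
    return per_token * max_model_len * max_num_seqs


# Same limits as the example above: 2048 tokens, 2 sequences.
estimate = kv_cache_bytes(max_model_len=2048, max_num_seqs=2)
print(f"{estimate / 2**20:.0f} MiB")
```

The bound scales linearly with both `max_model_len` and `max_num_seqs`, so halving either one halves it.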

#### Reduce CUDA Graphs

By default, we optimize model inference using CUDA graphs, which take up extra GPU memory.

:::{important}
CUDA graph capture takes up more memory in V1 than in V0.
:::

You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage:

```python
from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        # By default, it goes up to max_num_seqs
        cudagraph_capture_sizes=[1, 2, 4, 8, 16],
    ),
)
```
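
Rather than writing the capture sizes by hand, you can generate them. The sketch below produces powers of two up to a chosen cap. `power_of_two_sizes` is a hypothetical helper, and using powers of two is an assumption that mirrors the list in the example above, not a vLLM requirement.

```python
def power_of_two_sizes(max_size):
    """Return [1, 2, 4, ...] up to and including max_size where possible."""
    sizes = []
    size = 1
    while size <= max_size:
        sizes.append(size)
        size *= 2
    return sizes


# Fewer and smaller capture sizes mean less memory spent on CUDA graphs.
capture_sizes = power_of_two_sizes(16)
```

The resulting list can be passed as `cudagraph_capture_sizes` in `CompilationConfig`.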

You can disable graph capturing completely via the `enforce_eager` flag:

```python
from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          enforce_eager=True)
```

#### Adjust cache size

If you run out of CPU RAM, try the following options:

- (Multi-modal models only) Set the size of the multi-modal input cache using the `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB).
- (CPU backend only) Set the size of the KV cache using the `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
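
Both variables can be set in the shell or from Python before vLLM is imported. The following sketch does the latter. The value of 2 GiB is an arbitrary example, chosen only because it is below the 4 GiB default.

```python
import os

# Set these before importing vLLM so that they are seen during initialization.
# Values are in GiB, matching the defaults listed above.
os.environ["VLLM_MM_INPUT_CACHE_GIB"] = "2"   # multi-modal models only
os.environ["VLLM_CPU_KVCACHE_SPACE"] = "2"    # CPU backend only
```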

#### Multi-modal input limits

You can allow a smaller number of multi-modal items per prompt to reduce the memory footprint of the model:

```python
from vllm import LLM

# Accept up to 3 images and 1 video per prompt
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          limit_mm_per_prompt={"image": 3, "video": 1})
```

You can go a step further and disable unused modalities completely by setting their limits to zero.
For example, if your application only accepts image input, there is no need to allocate any memory for videos.

```python
from vllm import LLM

# Accept any number of images but no videos
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          limit_mm_per_prompt={"video": 0})
```

You can even run a multi-modal model for text-only inference:

```python
from vllm import LLM

# Don't accept images. Just text.
llm = LLM(model="google/gemma-3-27b-it",
          limit_mm_per_prompt={"image": 0})
```
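
If your application knows up front which modalities it accepts, you can build the limits dictionary so that every other modality is disabled explicitly. This is a sketch. `limits_for` is a hypothetical helper, and the default modality names are only the two that appear in the examples above.

```python
def limits_for(accepted, known_modalities=("image", "video")):
    """Build a `limit_mm_per_prompt` dict that zeroes out unlisted modalities."""
    # Start with every known modality disabled, then apply the accepted limits.
    limits = {name: 0 for name in known_modalities}
    limits.update(accepted)
    return limits


# Accept up to 3 images per prompt and no videos.
mm_limits = limits_for({"image": 3})
```

The result can be passed as `limit_mm_per_prompt=mm_limits` when creating the `LLM` instance.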

### Performance optimization and tuning

You can potentially improve the performance of vLLM by tuning various options.
Please refer to [this guide](#optimization-and-tuning) for more details.