arch_overview.md 10.5 KB
Newer Older
1
2
3
4
---
title: Architecture Overview
---
[](){ #arch-overview }
5
6
7

This document provides an overview of the vLLM architecture.

8
[TOC]
9

10
## Entrypoints
11
12
13
14

vLLM provides a number of entrypoints for interacting with the system. The
following diagram shows the relationship between them.

15
![Entrypoints Diagram](../assets/design/arch_overview/entrypoints.excalidraw.png)
16

17
### LLM Class
18
19
20
21
22
23
24

The LLM class provides the primary Python interface for doing offline inference,
which is interacting with a model without using a separate model inference
server.

Here is a sample of `LLM` class usage:

25
26
```python
from vllm import LLM, SamplingParams
27

28
29
30
31
32
33
# Define a list of input prompts
prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The largest ocean is",
]
34

35
36
# Define sampling parameters
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
37

38
39
# Initialize the LLM engine with the OPT-125M model
llm = LLM(model="facebook/opt-125m")
40

41
42
# Generate outputs for the input prompts
outputs = llm.generate(prompts, sampling_params)
43

44
45
46
47
48
49
# Print the generated outputs
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
50

H's avatar
H committed
51
More API details can be found in the [Offline Inference](#offline-inference-api) section of the API docs.
52

53
The code for the `LLM` class can be found in <gh-file:vllm/entrypoints/llm.py>.
54

55
### OpenAI-Compatible API Server
56
57
58
59

The second primary interface to vLLM is via its OpenAI-compatible API server.
This server can be started using the `vllm serve` command.

60
61
62
```bash
vllm serve <model>
```
63

64
The code for the `vllm` CLI can be found in <gh-file:vllm/entrypoints/cli/main.py>.
65
66
67
68

Sometimes you may see the API server entrypoint used directly instead of via the
`vllm` CLI command. For example:

69
70
71
```bash
python -m vllm.entrypoints.openai.api_server --model <model>
```
72

73
That code can be found in <gh-file:vllm/entrypoints/openai/api_server.py>.
74

75
More details on the API server can be found in the [OpenAI-Compatible Server][openai-compatible-server] document.
76

77
## LLM Engine
78
79
80
81

The `LLMEngine` and `AsyncLLMEngine` classes are central to the functioning of
the vLLM system, handling model inference and asynchronous request processing.

82
![LLMEngine Diagram](../assets/design/arch_overview/llm_engine.excalidraw.png)
83

84
### LLMEngine
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99

The `LLMEngine` class is the core component of the vLLM engine. It is
responsible for receiving requests from clients and generating outputs from the
model. The `LLMEngine` includes input processing, model execution (possibly
distributed across multiple hosts and/or GPUs), scheduling, and output
processing.

- **Input Processing**: Handles tokenization of input text using the specified
  tokenizer.
- **Scheduling**: Chooses which requests are processed in each step.
- **Model Execution**: Manages the execution of the language model, including
  distributed execution across multiple GPUs.
- **Output Processing**: Processes the outputs generated by the model, decoding the
  token IDs from a language model into human-readable text.

100
The code for `LLMEngine` can be found in <gh-file:vllm/engine/llm_engine.py>.
101

102
### AsyncLLMEngine
103
104
105
106
107
108
109

The `AsyncLLMEngine` class is an asynchronous wrapper for the `LLMEngine` class.
It uses `asyncio` to create a background loop that continuously processes
incoming requests. The `AsyncLLMEngine` is designed for online serving, where it
can handle multiple concurrent requests and stream outputs to clients.

The OpenAI-compatible API server uses the `AsyncLLMEngine`. There is also a demo
110
API server that serves as a simpler example in <gh-file:vllm/entrypoints/api_server.py>.
111

112
The code for `AsyncLLMEngine` can be found in <gh-file:vllm/engine/async_llm_engine.py>.
113

114
## Worker
115
116
117
118
119

A worker is a process that runs the model inference. vLLM follows the common
practice of using one process to control one accelerator device, such as GPUs.
For example, if we use tensor parallelism of size 2 and pipeline parallelism of
size 2, we will have 4 workers in total. Workers are identified by their
120
121
`rank` and `local_rank`. `rank` is used for global orchestration, while
`local_rank` is mainly used for assigning the accelerator device and accessing
122
123
local resources such as the file system and shared memory.

124
## Model Runner
125
126
127
128
129

Every worker has one model runner object, responsible for loading and running
the model. Much of the model execution logic resides here, such as preparing
input tensors and capturing cudagraphs.

130
## Model
131
132

Every model runner object has one model object, which is the actual
133
`torch.nn.Module` instance. See [huggingface_integration][huggingface-integration] for how various
134
135
configurations affect the class we ultimately get.

136
## Class Hierarchy
137
138
139

The following figure shows the class hierarchy of vLLM:

140
141
142
> <figure markdown="span">
>   ![](../assets/design/hierarchy.png){ align="center" alt="query" width="100%" }
> </figure>
143
144
145

There are several important design choices behind this class hierarchy:

146
147
1\. **Extensibility**: All classes in the hierarchy accept a configuration object
containing all the necessary information. The [VllmConfig](https://github.com/vllm-project/vllm/blob/d1c6799b8870e513bf4f2305cbf6cda9fc3d773b/vllm/config.py#L2036)
148
149
150
151
152
153
154
155
156
157
158
159
class is the main configuration object that is passed around. The class
hierarchy is quite deep, and every class needs to read the configuration it is
interested in. By encapsulating all configurations in one object, we can easily
pass the configuration object around and access the configuration we need.
Suppose we want to add a new feature (this is often the case given how fast the
field of LLM inference is evolving) that only touches the model runner. We will
have to add a new configuration option in the `VllmConfig` class. Since we pass
the whole config object around, we only need to add the configuration option to
the `VllmConfig` class, and the model runner can access it directly. We don't
need to change the constructor of the engine, worker, or model class to pass the
new configuration option.

160
2\. **Uniformity**: The model runner needs a unified interface to create and
161
162
163
164
165
166
167
168
169
170
171
initialize the model. vLLM supports more than 50 types of popular open-source
models. Each model has its own initialization logic. If the constructor
signature varies with models, the model runner does not know how to call the
constructor accordingly, without complicated and error-prone inspection logic.
By making the constructor of the model class uniform, the model runner can
easily create and initialize the model without knowing the specific model type.
This is also useful for composing models. Vision-language models often consist
of a vision model and a language model. By making the constructor uniform, we
can easily create a vision model and a language model and compose them into a
vision-language model.

172
173
!!! note
    To support this change, all vLLM models' signatures have been updated to:
174

175
    ```python
176
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
    ```

    To avoid accidentally passing incorrect arguments, the constructor is now keyword-only. This ensures that the constructor will raise an error if old configurations are passed. vLLM developers have already made this change for all models within vLLM. For out-of-tree registered models, developers need to update their models, for example by adding shim code to adapt the old constructor signature to the new one:

    ```python
    class MyOldModel(nn.Module):
        def __init__(
            self,
            config,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
            lora_config: Optional[LoRAConfig] = None,
            prefix: str = "",
        ) -> None:
            ...

    from vllm.config import VllmConfig
    class MyNewModel(MyOldModel):
        def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            config = vllm_config.model_config.hf_config
            cache_config = vllm_config.cache_config
            quant_config = vllm_config.quant_config
            lora_config = vllm_config.lora_config
            super().__init__(config, cache_config, quant_config, lora_config, prefix)

    if __version__ >= "0.6.4":
        MyModel = MyNewModel
    else:
        MyModel = MyOldModel
    ```

    This way, the model can work with both old and new versions of vLLM.
209
210

3\. **Sharding and Quantization at Initialization**: Certain features require
211
212
213
214
215
216
217
218
219
220
221
222
changing the model weights. For example, tensor parallelism needs to shard the
model weights, and quantization needs to quantize the model weights. There are
two possible ways to implement this feature. One way is to change the model
weights after the model is initialized. The other way is to change the model
weights during the model initialization. vLLM chooses the latter. The first
approach is not scalable to large models. Suppose we want to run a 405B model
(with roughly 810GB weights) with 16 H100 80GB GPUs. Ideally, every GPU should
only load 50GB weights. If we change the model weights after the model is
initialized, we need to load the full 810GB weights to every GPU and then shard
the weights, leading to a huge memory overhead. Instead, if we shard the weights
during the model initialization, every layer will only create a shard of the
weights it needs, leading to a much smaller memory overhead. The same idea
223
applies to quantization. Note that we also add an additional argument `prefix`
224
225
to the model's constructor so that the model can initialize itself differently
based on the prefix. This is useful for non-uniform quantization, where
226
227
228
different parts of the model are quantized differently. The `prefix` is
usually an empty string for the top-level model and a string like `"vision"`
or `"language"` for the sub-models. In general, it matches the name of the
229
230
231
232
233
234
module's state dict in the checkpoint file.

One disadvantage of this design is that it is hard to write unit tests for
individual components in vLLM because every component needs to be initialized by
a complete config object. We solve this problem by providing a default
initialization function that creates a default config object with all fields set
235
to `None`. If the component we want to test only cares about a few fields in
236
237
238
239
240
the config object, we can create a default config object and set the fields we
care about. This way, we can test the component in isolation. Note that many
tests in vLLM are end-to-end tests that test the whole system, so this is not a
big problem.

241
In summary, the complete config object `VllmConfig` can be treated as an
242
engine-level global state that is shared among all vLLM classes.