# Pooling Models

vLLM also supports pooling models, such as embedding, classification, and reward models.

In vLLM, pooling models implement the [VllmModelForPooling][vllm.model_executor.models.VllmModelForPooling] interface.
These models use a [Pooler][vllm.model_executor.layers.pooler.Pooler] to extract the final hidden states of the input
before returning them.

!!! note
    We currently support pooling models primarily for convenience. This is not guaranteed to provide any performance improvements over using Hugging Face Transformers or Sentence Transformers directly.

    We plan to optimize pooling models in vLLM. Please comment on <https://github.com/vllm-project/vllm/issues/21796> if you have any suggestions!

## Configuration

### Model Runner

Run a model in pooling mode via the option `--runner pooling`.

!!! tip
    There is no need to set this option in the vast majority of cases as vLLM can automatically
    detect the appropriate model runner via `--runner auto`.

### Model Conversion

vLLM can adapt models for various pooling tasks via the option `--convert <type>`.

If `--runner pooling` has been set (manually or automatically) but the model does not implement the
[VllmModelForPooling][vllm.model_executor.models.VllmModelForPooling] interface,
vLLM will attempt to automatically convert the model according to the architecture names
shown in the table below.

| Architecture                                    | `--convert` | Supported pooling tasks               |
|-------------------------------------------------|-------------|---------------------------------------|
| `*ForTextEncoding`, `*EmbeddingModel`, `*Model` | `embed`     | `token_embed`, `embed`                |
| `*For*Classification`, `*ClassificationModel`   | `classify`  | `token_classify`, `classify`, `score` |
| `*ForRewardModeling`, `*RewardModel`            | `reward`    | `token_classify`                      |
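
The name matching in the table above can be pictured as glob-style suffix matching on the architecture name. The following is a hypothetical illustration, not vLLM's resolution code. `guess_convert_type` is an invented helper, and checking the `reward` and `classify` patterns before the catch-all `*Model` is an assumption made here so that names such as `Qwen2ForRewardModel` are not classified as `embed`.

```python
from fnmatch import fnmatchcase
from typing import Optional

# Patterns copied from the table above. Checking `reward` and `classify`
# before `embed` is an assumption of this sketch: the catch-all "*Model"
# would otherwise also match names such as "Qwen2ForRewardModel".
CONVERT_RULES = [
    ("reward", ("*ForRewardModeling", "*RewardModel")),
    ("classify", ("*For*Classification", "*ClassificationModel")),
    ("embed", ("*ForTextEncoding", "*EmbeddingModel", "*Model")),
]


def guess_convert_type(architecture: str) -> Optional[str]:
    """Return the `--convert` type whose pattern matches, else None (hypothetical helper)."""
    for convert_type, patterns in CONVERT_RULES:
        if any(fnmatchcase(architecture, pattern) for pattern in patterns):
            return convert_type
    return None


for arch in ["Qwen2ForSequenceClassification", "BertModel", "LlamaForCausalLM"]:
    print(f"{arch} -> {guess_convert_type(arch)}")
```

A name that matches none of the patterns (such as a `*ForCausalLM` architecture) yields `None` here, which corresponds to the case where you would pass `--convert <type>` explicitly.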

!!! tip
    You can explicitly set `--convert <type>` to specify how to convert the model.

### Pooling Tasks

Each pooling model in vLLM supports one or more of these tasks according to
[Pooler.get_supported_tasks][vllm.model_executor.layers.pooler.Pooler.get_supported_tasks],
enabling the corresponding APIs:

| Task             | APIs                                                                          |
|------------------|-------------------------------------------------------------------------------|
| `embed`          | `LLM.embed(...)`, `LLM.score(...)`\*, `LLM.encode(..., pooling_task="embed")` |
| `classify`       | `LLM.classify(...)`, `LLM.encode(..., pooling_task="classify")`               |
| `score`          | `LLM.score(...)`                                                              |
| `token_classify` | `LLM.reward(...)`, `LLM.encode(..., pooling_task="token_classify")`           |
| `token_embed`    | `LLM.encode(..., pooling_task="token_embed")`                                 |
| `plugin`         | `LLM.encode(..., pooling_task="plugin")`                                      |

\* The `LLM.score(...)` API falls back to the `embed` task if the model does not support the `score` task.

### Pooler Configuration

#### Predefined models

If the [Pooler][vllm.model_executor.layers.pooler.Pooler] defined by the model accepts `pooler_config`,
you can override some of its attributes via the `--pooler-config` option.

#### Converted models

If the model has been converted via `--convert` (see above),
the pooler assigned to each task has the following attributes by default:

| Task       | Pooling Type | Normalization | Softmax |
|------------|--------------|---------------|---------|
| `reward`   | `ALL`        | ❌            | ❌     |
| `embed`    | `LAST`       | ✅︎            | ❌      |
| `classify` | `LAST`       | ❌            | ✅︎      |
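
To make the table concrete, here is a toy, pure-Python sketch of what each pooling type and activation does to a prompt's per-token hidden states. It is illustrative only. The function names are hypothetical rather than vLLM APIs, and a real classifier applies softmax to the logits produced by its classification head, which the pooled toy vector merely stands in for.

```python
import math
from typing import List

Vector = List[float]


def pool_last(hidden_states: List[Vector]) -> Vector:
    """LAST pooling: keep only the hidden state of the final token."""
    return list(hidden_states[-1])


def pool_all(hidden_states: List[Vector]) -> List[Vector]:
    """ALL pooling: keep one hidden state per token."""
    return [list(h) for h in hidden_states]


def l2_normalize(v: Vector) -> Vector:
    """Scale a vector to unit L2 norm (left unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def softmax(v: Vector) -> Vector:
    """Turn scores into probabilities; subtracting the max avoids overflow."""
    m = max(v)
    exps = [math.exp(x - m) for x in v]
    total = sum(exps)
    return [e / total for e in exps]


# Toy hidden states for a 3-token prompt with hidden size 2.
hidden = [[0.5, 1.0], [2.0, -1.0], [3.0, 4.0]]

embed_out = l2_normalize(pool_last(hidden))  # `embed`: LAST + normalization
classify_out = softmax(pool_last(hidden))    # `classify`: LAST + softmax (toy logits)
reward_out = pool_all(hidden)                # `reward`: ALL, no activation
print(embed_out, classify_out, reward_out)
```

Note that `reward` keeps one output per token, which is why its converted models expose the token-wise `token_classify` task.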

When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
their Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults.

You can further customize this via the `--pooler-config` option,
which takes priority over both the model's and Sentence Transformers' defaults.
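
The precedence order described above (model defaults, then `modules.json`, then `--pooler-config`) can be sketched as a layered dictionary merge. This illustration rests on assumptions. The helper `merge_pooler_config` and the key names `pooling_type` and `normalize` are hypothetical, and overriding key by key is assumed here rather than taken from vLLM's implementation.

```python
from typing import Any, Dict, Optional


def merge_pooler_config(
    model_defaults: Dict[str, Any],
    st_config: Optional[Dict[str, Any]] = None,
    cli_overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical merge: model defaults < modules.json < --pooler-config."""
    merged = dict(model_defaults)
    for source in (st_config or {}, cli_overrides or {}):
        # Only keys that are explicitly set (not None) override earlier layers.
        merged.update({k: v for k, v in source.items() if v is not None})
    return merged


config = merge_pooler_config(
    model_defaults={"pooling_type": "LAST", "normalize": True},
    st_config={"pooling_type": "MEAN"},
    cli_overrides={"normalize": False},
)
print(config)
```

In this sketch, the Sentence Transformers layer changes the pooling type and the command-line layer then disables normalization. Each later layer wins only for the keys it actually sets.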

## Offline Inference

The [LLM][vllm.LLM] class provides various methods for offline inference.
See [configuration](../api/README.md#configuration) for a list of options when initializing the model.

### `LLM.embed`

The [embed][vllm.LLM.embed] method outputs an embedding vector for each prompt.
It is primarily designed for embedding models.

```python
from vllm import LLM

llm = LLM(model="intfloat/e5-small", runner="pooling")
(output,) = llm.embed("Hello, my name is")

embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")
```

A code example can be found here: [examples/offline_inference/basic/embed.py](../../examples/offline_inference/basic/embed.py)

### `LLM.classify`

The [classify][vllm.LLM.classify] method outputs a probability vector for each prompt.
It is primarily designed for classification models.

```python
from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", runner="pooling")
(output,) = llm.classify("Hello, my name is")

probs = output.outputs.probs
print(f"Class Probabilities: {probs!r} (size={len(probs)})")
```

A code example can be found here: [examples/offline_inference/basic/classify.py](../../examples/offline_inference/basic/classify.py)

### `LLM.score`

The [score][vllm.LLM.score] method outputs similarity scores between sentence pairs.
It is designed for embedding models and cross-encoder models. Embedding models use cosine similarity, and [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html) serve as rerankers between candidate query-document pairs in RAG systems.
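
For embedding models, the score is the cosine similarity between the two embeddings. The plain-Python sketch below shows the formula on toy vectors. `cosine_similarity` is a hypothetical helper, not a vLLM function.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """cos(a, b) = (a . b) / (|a| * |b|), a value in [-1, 1]."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Toy vectors standing in for the embeddings of a query and a document.
query_emb = [0.6, 0.8, 0.0]
doc_emb = [0.8, 0.6, 0.0]
print(round(cosine_similarity(query_emb, doc_emb), 4))  # prints 0.96
```

When both embeddings are already L2-normalized, the denominator is 1 and cosine similarity reduces to a plain dot product.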

!!! note
    vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
    To handle RAG at a higher level, you should use integration frameworks such as [LangChain](https://github.com/langchain-ai/langchain).

```python
from vllm import LLM

llm = LLM(model="BAAI/bge-reranker-v2-m3", runner="pooling")
(output,) = llm.score(
    "What is the capital of France?",
    "The capital of Brazil is Brasilia.",
)

score = output.outputs.score
print(f"Score: {score}")
```

A code example can be found here: [examples/offline_inference/basic/score.py](../../examples/offline_inference/basic/score.py)

### `LLM.reward`

The [reward][vllm.LLM.reward] method is available to all reward models in vLLM.

```python
from vllm import LLM

llm = LLM(model="internlm/internlm2-1_8b-reward", runner="pooling", trust_remote_code=True)
(output,) = llm.reward("Hello, my name is")

data = output.outputs.data
print(f"Data: {data!r}")
```

A code example can be found here: [examples/offline_inference/basic/reward.py](../../examples/offline_inference/basic/reward.py)

### `LLM.encode`

The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.

!!! note
    Please use one of the more specific methods or set the task directly when using `LLM.encode`:

    - For embeddings, use `LLM.embed(...)` or `pooling_task="embed"`.
    - For classification logits, use `LLM.classify(...)` or `pooling_task="classify"`.
    - For similarity scores, use `LLM.score(...)`.
    - For rewards, use `LLM.reward(...)` or `pooling_task="token_classify"`.
    - For token classification, use `pooling_task="token_classify"`.
    - For multi-vector retrieval, use `pooling_task="token_embed"`.
    - For IO Processor Plugins, use `pooling_task="plugin"`.

```python
from vllm import LLM

llm = LLM(model="intfloat/e5-small", runner="pooling")
(output,) = llm.encode("Hello, my name is", pooling_task="embed")

data = output.outputs.data
print(f"Data: {data!r}")
```

## Online Serving

Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs:

- [Embeddings API](../serving/openai_compatible_server.md#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](../features/multimodal_inputs.md) for embedding models.
- [Classification API](../serving/openai_compatible_server.md#classification-api) is similar to `LLM.classify` and is applicable to sequence classification models.
- [Score API](../serving/openai_compatible_server.md#score-api) is similar to `LLM.score` for cross-encoder models.
- [Pooling API](../serving/openai_compatible_server.md#pooling-api) is similar to `LLM.encode` and is applicable to all types of pooling models.

!!! note
    Please use one of the more specific endpoints or set the task directly when using the [Pooling API](../serving/openai_compatible_server.md#pooling-api):

    - For embeddings, use the [Embeddings API](../serving/openai_compatible_server.md#embeddings-api) or `"task":"embed"`.
    - For classification logits, use the [Classification API](../serving/openai_compatible_server.md#classification-api) or `"task":"classify"`.
    - For similarity scores, use the [Score API](../serving/openai_compatible_server.md#score-api).
    - For rewards, use `"task":"token_classify"`.
    - For token classification, use `"task":"token_classify"`.
    - For multi-vector retrieval, use `"task":"token_embed"`.
    - For IO Processor Plugins, use `"task":"plugin"`.

```python
# start a supported embeddings model server with `vllm serve`, e.g.
# vllm serve intfloat/e5-small
import requests

host = "localhost"
port = "8000"
model_name = "intfloat/e5-small"

api_url = f"http://{host}:{port}/pooling"

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
prompt = {"model": model_name, "input": prompts, "task": "embed"}

response = requests.post(api_url, json=prompt)

for output in response.json()["data"]:
    data = output["data"]
    print(f"Data: {data!r} (size={len(data)})")
```

## Matryoshka Embeddings

[Matryoshka Embeddings](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings) or [Matryoshka Representation Learning (MRL)](https://arxiv.org/abs/2205.13147) is a technique used in training embedding models. It allows users to trade off between performance and cost.
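
The idea behind MRL is that a prefix of the full embedding is itself a usable, lower-dimensional embedding. Shorter vectors are cheaper to store and compare, at some cost in quality. Here is a minimal client-side sketch, assuming the common convention of keeping the first `dimensions` values and re-normalizing them to unit length. `truncate_embedding` is a hypothetical helper, and this is not a description of how vLLM implements the `dimensions` parameter internally.

```python
import math
from typing import List


def truncate_embedding(embedding: List[float], dimensions: int) -> List[float]:
    """Keep the first `dimensions` values and re-normalize to unit L2 norm."""
    if not 0 < dimensions <= len(embedding):
        raise ValueError(f"dimensions must be in [1, {len(embedding)}]")
    head = embedding[:dimensions]
    norm = math.sqrt(sum(x * x for x in head))
    return [x / norm for x in head] if norm > 0 else head


full = [0.3, 0.4, 0.5, 0.5, 0.5]      # toy 5-dimensional embedding
short = truncate_embedding(full, 2)  # approximately [0.6, 0.8]
print(short)
```

Truncation like this only gives sensible results for models trained with MRL. For other models, the leading dimensions carry no special meaning.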

!!! warning
    Not all embedding models are trained using Matryoshka Representation Learning. To avoid misuse of the `dimensions` parameter, vLLM returns an error for requests that attempt to change the output dimension of models that do not support Matryoshka Embeddings.

    For example, setting the `dimensions` parameter while using the `BAAI/bge-m3` model will result in the following error.

    ```json
    {"object":"error","message":"Model \"BAAI/bge-m3\" does not support matryoshka representation, changing output dimensions will lead to poor results.","type":"BadRequestError","param":null,"code":400}
    ```

### Manually enable Matryoshka Embeddings

There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if `is_matryoshka` is `True` in `config.json`, you can change the output dimension to arbitrary values. Use `matryoshka_dimensions` to control the allowed output dimensions.

For models that support Matryoshka Embeddings but are not recognized by vLLM, manually override the config using `hf_overrides={"is_matryoshka": True}` or `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline), or `--hf-overrides '{"is_matryoshka": true}'` or `--hf-overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'` (online).
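
The rules above can be summarized as a small validation function. This is a hypothetical sketch of the described behavior, not vLLM's code. `check_dimensions` is an invented name, and `hf_config` stands for the model's `config.json` after any `hf_overrides` have been applied.

```python
from typing import Any, Dict, List, Optional


def check_dimensions(hf_config: Dict[str, Any], dimensions: Optional[int]) -> None:
    """Raise ValueError if `dimensions` is not allowed for this config (hypothetical)."""
    if dimensions is None:
        return  # no change to the output dimension was requested
    allowed: Optional[List[int]] = hf_config.get("matryoshka_dimensions")
    # Either key marks the model as supporting Matryoshka Embeddings.
    if not hf_config.get("is_matryoshka") and allowed is None:
        raise ValueError("model does not support matryoshka representation")
    # With `matryoshka_dimensions`, only the listed sizes are accepted;
    # with `is_matryoshka` alone, any size is accepted.
    if allowed is not None and dimensions not in allowed:
        raise ValueError(f"dimensions={dimensions} is not one of {allowed}")


check_dimensions({"is_matryoshka": True}, 64)           # accepted
check_dimensions({"matryoshka_dimensions": [256]}, 256)  # accepted
```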

Here is an example of serving a model with Matryoshka Embeddings enabled:

```bash
vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf-overrides '{"matryoshka_dimensions":[256]}'
```

### Offline Inference

You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the `dimensions` parameter in [PoolingParams][vllm.PoolingParams].

```python
from vllm import LLM, PoolingParams

llm = LLM(
    model="jinaai/jina-embeddings-v3",
    runner="pooling",
    trust_remote_code=True,
)
# Request 32-dimensional embeddings instead of the model's full output size.
outputs = llm.embed(
    ["Follow the white rabbit."],
    pooling_params=PoolingParams(dimensions=32),
)

print(outputs[0].outputs)
```

A code example can be found here: [examples/offline_inference/pooling/embed_matryoshka_fy.py](../../examples/offline_inference/pooling/embed_matryoshka_fy.py)

### Online Inference

Use the following command to start the vLLM server.

```bash
vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
```

You can change the output dimensions of embedding models that support Matryoshka Embeddings by setting the `dimensions` parameter in the request body.

```bash
curl http://127.0.0.1:8000/v1/embeddings \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
    "input": "Follow the white rabbit.",
    "model": "jinaai/jina-embeddings-v3",
    "encoding_format": "float",
    "dimensions": 32
  }'
```

Expected output:

```json
{"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}}
```

An OpenAI client example can be found here: [examples/online_serving/pooling/openai_embedding_matryoshka_fy.py](../../examples/online_serving/pooling/openai_embedding_matryoshka_fy.py)

## Deprecated Features

### Encode task

We have split the `encode` task into two more specific token-wise tasks, `token_embed` and `token_classify`:

- `token_embed` is the token-wise counterpart of `embed`, using normalization as the activation.
- `token_classify` is the token-wise counterpart of `classify`, using softmax as the activation by default.

### Remove softmax from PoolingParams

We are going to remove `softmax` and `activation` from `PoolingParams`. Use `use_activation` instead, since `classify` and `token_classify` are allowed to use any activation function.