Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
65f311ce
Unverified
Commit
65f311ce
authored
Jul 30, 2025
by
wang.yuqi
Committed by
GitHub
Jul 29, 2025
Browse files
[Frontend] Add LLM.reward specific to reward models (#21720)
Signed-off-by:
wang.yuqi
<
noooop@126.com
>
parent
1b0a1555
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
174 additions
and
35 deletions
+174
-35
docs/models/pooling_models.md
docs/models/pooling_models.md
+53
-28
examples/offline_inference/basic/embed.py
examples/offline_inference/basic/embed.py
+1
-2
examples/offline_inference/basic/reward.py
examples/offline_inference/basic/reward.py
+53
-0
tests/conftest.py
tests/conftest.py
+4
-0
tests/models/language/pooling/test_reward.py
tests/models/language/pooling/test_reward.py
+1
-1
tests/models/language/pooling/test_truncation_control.py
tests/models/language/pooling/test_truncation_control.py
+3
-3
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+59
-1
No files found.
docs/models/pooling_models.md
View file @
65f311ce
...
...
@@ -45,14 +45,14 @@ Each pooling model in vLLM supports one or more of these tasks according to
[
Pooler.get_supported_tasks
][
vllm.model_executor.layers.pooler.Pooler.get_supported_tasks
]
,
enabling the corresponding APIs:
| Task | APIs |
|------------|--------------------|
|
`encode`
|
`
encode`
|
|
`embed`
|
`embed
`
,
`score
`
\*
|
|
`classify`
|
`classify
`
|
|
`score`
|
`score
`
|
| Task | APIs
|
|------------|--------------------
------------------
|
|
`encode`
|
`
LLM.reward(...)`
|
|
`embed`
|
`
LLM.
embed
(...)`
,
`LLM.score(...)
`
\*
|
|
`classify`
|
`
LLM.
classify
(...)`
|
|
`score`
|
`
LLM.
score
(...)`
|
\*
The
`score`
API falls back to
`embed`
task if the model does not support
`score`
task.
\*
The
`
LLM.
score
(...)
`
API falls back to
`embed`
task if the model does not support
`score`
task.
### Pooler Configuration
...
...
@@ -66,11 +66,11 @@ you can override some of its attributes via the `--override-pooler-config` optio
If the model has been converted via
`--convert`
(see above),
the pooler assigned to each task has the following attributes by default:
| Task | Pooling Type
| Normalization | Softmax |
|------------|--------------
--
|---------------|---------|
|
`
encode
`
|
`ALL`
| ❌ | ❌
|
|
`embed`
|
`LAST`
| ✅︎ | ❌ |
|
`classify`
|
`LAST`
| ❌ | ✅︎ |
| Task | Pooling Type | Normalization | Softmax |
|------------|--------------|---------------|---------|
|
`
reward
`
|
`ALL`
| ❌ | ❌ |
|
`embed`
|
`LAST`
| ✅︎ | ❌ |
|
`classify`
|
`LAST`
| ❌ | ✅︎ |
When loading
[
Sentence Transformers
](
https://huggingface.co/sentence-transformers
)
models,
its Sentence Transformers configuration file (
`modules.json`
) takes priority over the model's defaults.
...
...
@@ -83,21 +83,6 @@ which takes priority over both the model's and Sentence Transformers's defaults.
The
[
LLM
][
vllm.LLM
]
class provides various methods for offline inference.
See
[
configuration
][
configuration
]
for a list of options when initializing the model.
### `LLM.encode`
The
[
encode
][
vllm.LLM.encode
]
method is available to all pooling models in vLLM.
It returns the extracted hidden states directly, which is useful for reward models.
```
python
from
vllm
import
LLM
llm
=
LLM
(
model
=
"Qwen/Qwen2.5-Math-RM-72B"
,
runner
=
"pooling"
)
(
output
,)
=
llm
.
encode
(
"Hello, my name is"
)
data
=
output
.
outputs
.
data
print
(
f
"Data:
{
data
!
r
}
"
)
```
### `LLM.embed`
The
[
embed
][
vllm.LLM.embed
]
method outputs an embedding vector for each prompt.
...
...
@@ -106,7 +91,7 @@ It is primarily designed for embedding models.
```
python
from
vllm
import
LLM
llm
=
LLM
(
model
=
"intfloat/e5-
mistral-7b-instruct
"
,
runner
=
"pooling"
)
llm
=
LLM
(
model
=
"intfloat/e5-
small
"
,
runner
=
"pooling"
)
(
output
,)
=
llm
.
embed
(
"Hello, my name is"
)
embeds
=
output
.
outputs
.
embedding
...
...
@@ -154,6 +139,46 @@ print(f"Score: {score}")
A code example can be found here:
<gh-file:examples
/
offline_inference
/
basic
/
score.py
>
### `LLM.reward`
The
[
reward
][
vllm.LLM.reward
]
method is available to all reward models in vLLM.
It returns the extracted hidden states directly.
```
python
from
vllm
import
LLM
llm
=
LLM
(
model
=
"internlm/internlm2-1_8b-reward"
,
runner
=
"pooling"
,
trust_remote_code
=
True
)
(
output
,)
=
llm
.
reward
(
"Hello, my name is"
)
data
=
output
.
outputs
.
data
print
(
f
"Data:
{
data
!
r
}
"
)
```
A code example can be found here:
<gh-file:examples
/
offline_inference
/
basic
/
reward.py
>
### `LLM.encode`
The
[
encode
][
vllm.LLM.encode
]
method is available to all pooling models in vLLM.
It returns the extracted hidden states directly.
!!! note
Please use one of the more specific methods or set the task directly when using
`LLM.encode`
:
- For embeddings, use `LLM.embed(...)` or `pooling_task="embed"`.
- For classification logits, use `LLM.classify(...)` or `pooling_task="classify"`.
- For rewards, use `LLM.reward(...)` or `pooling_task="reward"`.
- For similarity scores, use `LLM.score(...)`.
```
python
from
vllm
import
LLM
llm
=
LLM
(
model
=
"intfloat/e5-small"
,
runner
=
"pooling"
)
(
output
,)
=
llm
.
encode
(
"Hello, my name is"
,
pooling_task
=
"embed"
)
data
=
output
.
outputs
.
data
print
(
f
"Data:
{
data
!
r
}
"
)
```
## Online Serving
Our
[
OpenAI-Compatible Server
](
../serving/openai_compatible_server.md
)
provides endpoints that correspond to the offline APIs:
...
...
examples/offline_inference/basic/embed.py
View file @
65f311ce
...
...
@@ -12,10 +12,9 @@ def parse_args():
parser
=
EngineArgs
.
add_cli_args
(
parser
)
# Set example specific arguments
parser
.
set_defaults
(
model
=
"intfloat/e5-
mistral-7b-instruct
"
,
model
=
"intfloat/e5-
small
"
,
runner
=
"pooling"
,
enforce_eager
=
True
,
max_model_len
=
1024
,
)
return
parser
.
parse_args
()
...
...
examples/offline_inference/basic/reward.py
0 → 100644
View file @
65f311ce
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
argparse
import
Namespace
from
vllm
import
LLM
,
EngineArgs
from
vllm.utils
import
FlexibleArgumentParser
def
parse_args
():
parser
=
FlexibleArgumentParser
()
parser
=
EngineArgs
.
add_cli_args
(
parser
)
# Set example specific arguments
parser
.
set_defaults
(
model
=
"internlm/internlm2-1_8b-reward"
,
runner
=
"pooling"
,
enforce_eager
=
True
,
max_model_len
=
1024
,
trust_remote_code
=
True
,
)
return
parser
.
parse_args
()
def
main
(
args
:
Namespace
):
# Sample prompts.
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
# Create an LLM.
# You should pass runner="pooling" for reward models
llm
=
LLM
(
**
vars
(
args
))
# Generate rewards. The output is a list of PoolingRequestOutput.
outputs
=
llm
.
reward
(
prompts
)
# Print the outputs.
print
(
"
\n
Generated Outputs:
\n
"
+
"-"
*
60
)
for
prompt
,
output
in
zip
(
prompts
,
outputs
):
rewards
=
output
.
outputs
.
data
rewards_trimmed
=
(
(
str
(
rewards
[:
16
])[:
-
1
]
+
", ...]"
)
if
len
(
rewards
)
>
16
else
rewards
)
print
(
f
"Prompt:
{
prompt
!
r
}
\n
Reward:
{
rewards_trimmed
}
(size=
{
len
(
rewards
)
}
)"
)
print
(
"-"
*
60
)
if
__name__
==
"__main__"
:
args
=
parse_args
()
main
(
args
)
tests/conftest.py
View file @
65f311ce
...
...
@@ -1053,6 +1053,10 @@ class VllmRunner:
req_outputs
=
self
.
llm
.
encode
(
prompts
)
return
[
req_output
.
outputs
.
data
for
req_output
in
req_outputs
]
def
reward
(
self
,
prompts
:
list
[
str
])
->
list
[
list
[
float
]]:
req_outputs
=
self
.
llm
.
reward
(
prompts
)
return
[
req_output
.
outputs
.
data
for
req_output
in
req_outputs
]
def
score
(
self
,
text_1
:
Union
[
str
,
list
[
str
]],
...
...
tests/models/language/pooling/test_reward.py
View file @
65f311ce
...
...
@@ -95,7 +95,7 @@ def test_prm_models(
monkeypatch
.
setenv
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"False"
)
with
vllm_runner
(
model
,
max_model_len
=
1024
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
encode
(
math_step_prompts
)
vllm_outputs
=
vllm_model
.
reward
(
math_step_prompts
)
with
hf_runner
(
model
,
dtype
=
dtype
,
auto_cls
=
AutoModel
)
as
hf_model
:
hf_model
=
step_reward_patch_hf_model
(
hf_model
)
...
...
tests/models/language/pooling/test_truncation_control.py
View file @
65f311ce
...
...
@@ -28,7 +28,7 @@ def test_smaller_truncation_size(vllm_runner,
with
vllm_runner
(
model_name
,
runner
=
"pooling"
,
max_model_len
=
max_model_len
)
as
vllm_model
:
vllm_output
=
vllm_model
.
llm
.
e
ncode
(
vllm_output
=
vllm_model
.
llm
.
e
mbed
(
input_str
,
truncate_prompt_tokens
=
truncate_prompt_tokens
)
prompt_tokens
=
vllm_output
[
0
].
prompt_token_ids
...
...
@@ -43,7 +43,7 @@ def test_max_truncation_size(vllm_runner,
with
vllm_runner
(
model_name
,
runner
=
"pooling"
,
max_model_len
=
max_model_len
)
as
vllm_model
:
vllm_output
=
vllm_model
.
llm
.
e
ncode
(
vllm_output
=
vllm_model
.
llm
.
e
mbed
(
input_str
,
truncate_prompt_tokens
=
truncate_prompt_tokens
)
prompt_tokens
=
vllm_output
[
0
].
prompt_token_ids
...
...
@@ -61,7 +61,7 @@ def test_bigger_truncation_size(vllm_runner,
model_name
,
runner
=
"pooling"
,
max_model_len
=
max_model_len
)
as
vllm_model
:
llm_output
=
vllm_model
.
llm
.
e
ncode
(
llm_output
=
vllm_model
.
llm
.
e
mbed
(
input_str
,
truncate_prompt_tokens
=
truncate_prompt_tokens
)
assert
llm_output
==
f
"""truncate_prompt_tokens value
...
...
vllm/entrypoints/llm.py
View file @
65f311ce
...
...
@@ -1037,7 +1037,7 @@ class LLM:
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
Union
[
bool
,
Callable
[...,
tqdm
]]
=
True
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
pooling_task
:
PoolingTask
=
"encode"
,
pooling_task
:
Optional
[
PoolingTask
]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
list
[
PoolingRequestOutput
]:
"""Apply pooling to the hidden states corresponding to the input
...
...
@@ -1069,6 +1069,25 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter.
"""
if
pooling_task
is
None
:
if
"embed"
in
self
.
supported_tasks
:
pooling_task
=
"embed"
else
:
pooling_task
=
"encode"
logger
.
warning_once
(
"`LLM.encode` is currently using `pooling_task = %s`.
\n
"
"Please use one of the more specific methods or set the "
"task directly when using `LLM.encode`:
\n
"
" - For embeddings, use `LLM.embed(...)` "
"or `pooling_task=
\"
embed
\"
`.
\n
"
" - For classification logits, use `LLM.classify(...)` "
"or `pooling_task=
\"
classify
\"
`.
\n
"
" - For rewards, use `LLM.reward(...)` "
"or `pooling_task=
\"
reward
\"
`
\n
"
" - For similarity scores, use `LLM.score(...)`."
,
pooling_task
)
model_config
=
self
.
llm_engine
.
model_config
runner_type
=
model_config
.
runner_type
if
runner_type
!=
"pooling"
:
...
...
@@ -1207,6 +1226,45 @@ class LLM:
return
[
ClassificationRequestOutput
.
from_base
(
item
)
for
item
in
items
]
def
reward
(
self
,
prompts
:
Union
[
PromptType
,
Sequence
[
PromptType
]],
/
,
*
,
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
Union
[
bool
,
Callable
[...,
tqdm
]]
=
True
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
)
->
list
[
PoolingRequestOutput
]:
"""
Generate rewards for each prompt.
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompts.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
Returns:
A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts.
"""
return
self
.
encode
(
prompts
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
pooling_params
=
pooling_params
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
pooling_task
=
"encode"
,
)
def
_embedding_score
(
self
,
tokenizer
:
AnyTokenizer
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment