Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
19332c04
Unverified
Commit
19332c04
authored
Sep 09, 2025
by
wang.yuqi
Committed by
GitHub
Sep 09, 2025
Browse files
[Model] Systematic support for fp32 head, pooling models part (#23810)
Signed-off-by:
wang.yuqi
<
noooop@126.com
>
parent
a55cf41a
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
166 additions
and
61 deletions
+166
-61
tests/models/language/pooling/mteb_utils.py
tests/models/language/pooling/mteb_utils.py
+31
-6
tests/models/language/pooling/test_bge_reranker_v2_gemma.py
tests/models/language/pooling/test_bge_reranker_v2_gemma.py
+1
-0
vllm/config/__init__.py
vllm/config/__init__.py
+52
-1
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+23
-15
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+4
-6
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+3
-1
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+8
-8
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+6
-4
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+11
-8
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+1
-1
vllm/model_executor/models/jina_vl.py
vllm/model_executor/models/jina_vl.py
+8
-5
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+3
-1
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+4
-0
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+11
-5
No files found.
tests/models/language/pooling/mteb_utils.py
View file @
19332c04
...
@@ -9,6 +9,7 @@ import mteb
...
@@ -9,6 +9,7 @@ import mteb
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
requests
import
requests
import
torch
from
tests.models.utils
import
(
EmbedModelInfo
,
RerankModelInfo
,
from
tests.models.utils
import
(
EmbedModelInfo
,
RerankModelInfo
,
check_embeddings_close
)
check_embeddings_close
)
...
@@ -165,16 +166,19 @@ def mteb_test_embed_models(hf_runner,
...
@@ -165,16 +166,19 @@ def mteb_test_embed_models(hf_runner,
vllm_extra_kwargs
=
None
,
vllm_extra_kwargs
=
None
,
hf_model_callback
=
None
,
hf_model_callback
=
None
,
atol
=
MTEB_EMBED_TOL
):
atol
=
MTEB_EMBED_TOL
):
if
not
model_info
.
enable_test
:
# A model family has many models with the same architecture,
# A model family has many models with the same architecture,
# and we don't need to test each one.
# and we don't need to test each one.
if
not
model_info
.
enable_test
:
pytest
.
skip
(
"Skipping test."
)
pytest
.
skip
(
"Skipping test."
)
example_prompts
=
[
"The chef prepared a delicious meal."
]
# Test embed_dims, isnan and whether to use normalize
example_prompts
=
[
"The chef prepared a delicious meal."
*
1000
]
# Allow vllm to test using the given dtype, such as float32
vllm_extra_kwargs
=
vllm_extra_kwargs
or
{}
vllm_extra_kwargs
=
vllm_extra_kwargs
or
{}
vllm_extra_kwargs
[
"dtype"
]
=
model_info
.
dtype
vllm_extra_kwargs
[
"dtype"
]
=
model_info
.
dtype
# Allow vllm to test using hf_overrides
if
model_info
.
hf_overrides
is
not
None
:
if
model_info
.
hf_overrides
is
not
None
:
vllm_extra_kwargs
[
"hf_overrides"
]
=
model_info
.
hf_overrides
vllm_extra_kwargs
[
"hf_overrides"
]
=
model_info
.
hf_overrides
...
@@ -186,21 +190,32 @@ def mteb_test_embed_models(hf_runner,
...
@@ -186,21 +190,32 @@ def mteb_test_embed_models(hf_runner,
model_config
=
vllm_model
.
llm
.
llm_engine
.
model_config
model_config
=
vllm_model
.
llm
.
llm_engine
.
model_config
# Confirm whether vllm is using the correct architecture
if
model_info
.
architecture
:
if
model_info
.
architecture
:
assert
model_info
.
architecture
in
model_config
.
architectures
assert
model_info
.
architecture
in
model_config
.
architectures
# Confirm whether vllm uses the correct default_pooling_type, which
# relates to whether chunked prefill and prefix caching are enabled
assert
(
model_config
.
_model_info
.
default_pooling_type
==
assert
(
model_config
.
_model_info
.
default_pooling_type
==
model_info
.
default_pooling_type
)
model_info
.
default_pooling_type
)
vllm_main_score
=
run_mteb_embed_task
(
VllmMtebEncoder
(
vllm_model
),
vllm_main_score
=
run_mteb_embed_task
(
VllmMtebEncoder
(
vllm_model
),
MTEB_EMBED_TASKS
)
MTEB_EMBED_TASKS
)
vllm_dtype
=
vllm_model
.
llm
.
llm_engine
.
model_config
.
dtype
vllm_dtype
=
vllm_model
.
llm
.
llm_engine
.
model_config
.
dtype
vllm_outputs
=
vllm_model
.
embed
(
example_prompts
)
# Test embed_dims, isnan and whether to use normalize
vllm_outputs
=
vllm_model
.
embed
(
example_prompts
,
truncate_prompt_tokens
=-
1
)
assert
not
torch
.
any
(
torch
.
isnan
(
torch
.
tensor
(
vllm_outputs
)))
# Accelerate mteb test by setting
# SentenceTransformers mteb score to a constant
if
model_info
.
mteb_score
is
None
:
if
model_info
.
mteb_score
is
None
:
with
hf_runner
(
model_info
.
name
,
with
hf_runner
(
model_info
.
name
,
is_sentence_transformer
=
True
,
is_sentence_transformer
=
True
,
dtype
=
"float32"
)
as
hf_model
:
dtype
=
"float32"
)
as
hf_model
:
# e.g. setting default parameters for the encode method of hf_runner
if
hf_model_callback
is
not
None
:
if
hf_model_callback
is
not
None
:
hf_model_callback
(
hf_model
)
hf_model_callback
(
hf_model
)
...
@@ -299,14 +314,16 @@ def mteb_test_rerank_models(hf_runner,
...
@@ -299,14 +314,16 @@ def mteb_test_rerank_models(hf_runner,
hf_model_callback
=
None
,
hf_model_callback
=
None
,
vllm_mteb_encoder
=
VllmMtebEncoder
,
vllm_mteb_encoder
=
VllmMtebEncoder
,
atol
=
MTEB_RERANK_TOL
):
atol
=
MTEB_RERANK_TOL
):
if
not
model_info
.
enable_test
:
# A model family has many models with the same architecture,
# A model family has many models with the same architecture,
# and we don't need to test each one.
# and we don't need to test each one.
if
not
model_info
.
enable_test
:
pytest
.
skip
(
"Skipping test."
)
pytest
.
skip
(
"Skipping test."
)
# Allow vllm to test using the given dtype, such as float32
vllm_extra_kwargs
=
vllm_extra_kwargs
or
{}
vllm_extra_kwargs
=
vllm_extra_kwargs
or
{}
vllm_extra_kwargs
[
"dtype"
]
=
model_info
.
dtype
vllm_extra_kwargs
[
"dtype"
]
=
model_info
.
dtype
# Allow vllm to test using hf_overrides
if
model_info
.
hf_overrides
is
not
None
:
if
model_info
.
hf_overrides
is
not
None
:
vllm_extra_kwargs
[
"hf_overrides"
]
=
model_info
.
hf_overrides
vllm_extra_kwargs
[
"hf_overrides"
]
=
model_info
.
hf_overrides
...
@@ -319,9 +336,15 @@ def mteb_test_rerank_models(hf_runner,
...
@@ -319,9 +336,15 @@ def mteb_test_rerank_models(hf_runner,
model_config
=
vllm_model
.
llm
.
llm_engine
.
model_config
model_config
=
vllm_model
.
llm
.
llm_engine
.
model_config
# Confirm whether vllm is using the correct architecture
if
model_info
.
architecture
:
if
model_info
.
architecture
:
assert
(
model_info
.
architecture
in
model_config
.
architectures
)
assert
(
model_info
.
architecture
in
model_config
.
architectures
)
# Score API is only enabled for num_labels == 1
assert
model_config
.
hf_config
.
num_labels
==
1
assert
model_config
.
hf_config
.
num_labels
==
1
# Confirm whether vllm uses the correct default_pooling_type, which
# relates to whether chunked prefill and prefix caching are enabled
assert
(
model_config
.
_model_info
.
default_pooling_type
==
assert
(
model_config
.
_model_info
.
default_pooling_type
==
model_info
.
default_pooling_type
)
model_info
.
default_pooling_type
)
...
@@ -330,6 +353,8 @@ def mteb_test_rerank_models(hf_runner,
...
@@ -330,6 +353,8 @@ def mteb_test_rerank_models(hf_runner,
languages
=
MTEB_RERANK_LANGS
)
languages
=
MTEB_RERANK_LANGS
)
vllm_dtype
=
model_config
.
dtype
vllm_dtype
=
model_config
.
dtype
# Accelerate mteb test by setting
# SentenceTransformers mteb score to a constant
if
model_info
.
mteb_score
is
None
:
if
model_info
.
mteb_score
is
None
:
st_main_score
,
st_dtype
=
mteb_test_rerank_models_hf
(
st_main_score
,
st_dtype
=
mteb_test_rerank_models_hf
(
hf_runner
,
model_info
.
name
,
hf_model_callback
)
hf_runner
,
model_info
.
name
,
hf_model_callback
)
...
...
tests/models/language/pooling/test_bge_reranker_v2_gemma.py
View file @
19332c04
...
@@ -14,6 +14,7 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
...
@@ -14,6 +14,7 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
RERANK_MODELS
=
[
RERANK_MODELS
=
[
LASTPoolingRerankModelInfo
(
"BAAI/bge-reranker-v2-gemma"
,
LASTPoolingRerankModelInfo
(
"BAAI/bge-reranker-v2-gemma"
,
architecture
=
"GemmaForSequenceClassification"
,
architecture
=
"GemmaForSequenceClassification"
,
mteb_score
=
0.33757
,
hf_overrides
=
{
hf_overrides
=
{
"architectures"
:
"architectures"
:
[
"GemmaForSequenceClassification"
],
[
"GemmaForSequenceClassification"
],
...
...
vllm/config/__init__.py
View file @
19332c04
...
@@ -745,7 +745,7 @@ class ModelConfig:
...
@@ -745,7 +745,7 @@ class ModelConfig:
self
.
pooler_config
=
self
.
_init_pooler_config
()
self
.
pooler_config
=
self
.
_init_pooler_config
()
self
.
dtype
=
_get_and_verify_dtype
(
self
.
dtype
:
torch
.
dtype
=
_get_and_verify_dtype
(
self
.
model
,
self
.
model
,
self
.
hf_config
,
self
.
hf_config
,
self
.
dtype
,
self
.
dtype
,
...
@@ -1751,6 +1751,32 @@ class ModelConfig:
...
@@ -1751,6 +1751,32 @@ class ModelConfig:
# `llm as reranker` models defaults to not using pad_token.
# `llm as reranker` models defaults to not using pad_token.
return
getattr
(
self
.
hf_config
,
"use_pad_token"
,
True
)
return
getattr
(
self
.
hf_config
,
"use_pad_token"
,
True
)
@
property
def
head_dtype
(
self
)
->
torch
.
dtype
:
"""
"head" refers to the last Linear layer(s) of an LLM,
such as the lm_head in a generation model,
or the score or classifier in a classification model.
The default head_dtype based on runner_type.
\n
- The pooling model defaults to using fp32 head,
you can use --hf-overrides '{"head_dtype": "model"}' to disable it.
\n
- The generate model defaults to not using fp32 head,
you can use --hf-overrides '{"head_dtype": "float32"}' to enable it.
"""
head_dtype
=
_get_head_dtype
(
config
=
self
.
hf_config
,
dtype
=
self
.
dtype
,
runner_type
=
self
.
runner_type
)
if
head_dtype
not
in
current_platform
.
supported_dtypes
:
logger
.
warning_once
(
"The current platform does not support [%s] head dtype, "
"fallback to model dtype [%s]."
,
head_dtype
,
self
.
dtype
)
return
self
.
dtype
logger
.
debug_once
(
"head dtype: %s"
,
head_dtype
)
return
head_dtype
def
get_and_verify_max_len
(
self
,
max_model_len
:
int
):
def
get_and_verify_max_len
(
self
,
max_model_len
:
int
):
# Consider max_model_len in tokenizer_config only when
# Consider max_model_len in tokenizer_config only when
# pooling models use absolute position_embedding.
# pooling models use absolute position_embedding.
...
@@ -2893,6 +2919,31 @@ def _get_and_verify_dtype(
...
@@ -2893,6 +2919,31 @@ def _get_and_verify_dtype(
return
torch_dtype
return
torch_dtype
def
_get_head_dtype
(
config
:
PretrainedConfig
,
dtype
:
torch
.
dtype
,
runner_type
:
str
)
->
torch
.
dtype
:
head_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
getattr
(
config
,
"head_dtype"
,
None
)
if
head_dtype
==
"model"
:
return
dtype
elif
isinstance
(
head_dtype
,
str
):
head_dtype
=
head_dtype
.
lower
()
if
head_dtype
not
in
_STR_DTYPE_TO_TORCH_DTYPE
:
raise
ValueError
(
f
"Unknown dtype:
{
head_dtype
!
r
}
"
)
return
_STR_DTYPE_TO_TORCH_DTYPE
[
head_dtype
]
elif
isinstance
(
head_dtype
,
torch
.
dtype
):
return
head_dtype
elif
head_dtype
is
None
:
if
torch
.
float32
not
in
current_platform
.
supported_dtypes
:
return
dtype
if
runner_type
==
"pooling"
:
return
torch
.
float32
return
dtype
else
:
raise
ValueError
(
f
"Unknown dtype:
{
head_dtype
}
"
)
def
_get_and_verify_max_len
(
def
_get_and_verify_max_len
(
hf_config
:
PretrainedConfig
,
hf_config
:
PretrainedConfig
,
tokenizer_config
:
Optional
[
dict
],
tokenizer_config
:
Optional
[
dict
],
...
...
vllm/model_executor/layers/pooler.py
View file @
19332c04
...
@@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
...
@@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
IntEnum
from
enum
import
IntEnum
from
itertools
import
groupby
from
itertools
import
groupby
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
,
cast
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -362,14 +362,13 @@ class PoolerIdentity(PoolerActivation):
...
@@ -362,14 +362,13 @@ class PoolerIdentity(PoolerActivation):
class
PoolerNormalize
(
PoolerActivation
):
class
PoolerNormalize
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
F
.
normalize
(
pooled_data
.
float
(),
p
=
2
,
dim
=-
1
)
return
F
.
normalize
(
pooled_data
,
p
=
2
,
dim
=-
1
)
return
x
.
to
(
pooled_data
.
dtype
)
class
PoolerMultiLabelClassify
(
PoolerActivation
):
class
PoolerMultiLabelClassify
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
F
.
sigmoid
(
pooled_data
.
float
()).
to
(
pooled_data
.
dtype
)
return
F
.
sigmoid
(
pooled_data
)
class
PoolerClassify
(
PoolerActivation
):
class
PoolerClassify
(
PoolerActivation
):
...
@@ -394,9 +393,9 @@ class PoolerClassify(PoolerActivation):
...
@@ -394,9 +393,9 @@ class PoolerClassify(PoolerActivation):
pooled_data
.
shape
[
-
1
])
pooled_data
.
shape
[
-
1
])
if
num_labels
<
2
:
if
num_labels
<
2
:
return
F
.
sigmoid
(
pooled_data
.
float
()).
to
(
pooled_data
.
dtype
)
return
F
.
sigmoid
(
pooled_data
)
return
F
.
softmax
(
pooled_data
.
float
()
,
dim
=-
1
)
.
to
(
pooled_data
.
dtype
)
return
F
.
softmax
(
pooled_data
,
dim
=-
1
)
class
LambdaPoolerActivation
(
PoolerActivation
):
class
LambdaPoolerActivation
(
PoolerActivation
):
...
@@ -432,8 +431,9 @@ class EmbeddingPoolerHead(PoolerHead):
...
@@ -432,8 +431,9 @@ class EmbeddingPoolerHead(PoolerHead):
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.model_executor.models.adapters
import
_load_st_projector
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
self
.
projector
=
_load_st_projector
(
self
.
projector
:
Optional
[
nn
.
Module
]
=
_load_st_projector
(
vllm_config
.
model_config
)
if
vllm_config
else
None
vllm_config
.
model_config
)
if
vllm_config
else
None
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
):
pooling_metadata
:
PoolingMetadata
):
...
@@ -442,16 +442,11 @@ class EmbeddingPoolerHead(PoolerHead):
...
@@ -442,16 +442,11 @@ class EmbeddingPoolerHead(PoolerHead):
pooled_data
=
torch
.
stack
(
pooled_data
)
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_dimension]
# pooled_data shape: [batchsize, hidden_dimension]
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# Apply ST projector
# Apply ST projector
if
self
.
projector
is
not
None
:
if
self
.
projector
is
not
None
:
projector
=
cast
(
nn
.
Module
,
self
.
projector
)
pooled_data
=
self
.
projector
(
pooled_data
)
def
_proj
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
orig_dtype
=
x
.
dtype
y
=
projector
(
x
.
to
(
torch
.
float32
))
return
y
.
to
(
orig_dtype
)
pooled_data
=
_proj
(
pooled_data
)
# pooled_data shape: [batchsize, embedding_dimension]
# pooled_data shape: [batchsize, embedding_dimension]
pooling_params
=
get_pooling_params
(
pooling_metadata
)
pooling_params
=
get_pooling_params
(
pooling_metadata
)
...
@@ -494,8 +489,18 @@ class RewardPoolerHead(PoolerHead):
...
@@ -494,8 +489,18 @@ class RewardPoolerHead(PoolerHead):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
(
activation
=
PoolerClassify
(
static_num_labels
=
False
))
super
().
__init__
(
activation
=
PoolerClassify
(
static_num_labels
=
False
))
from
vllm.config
import
get_current_vllm_config
vllm_config
=
get_current_vllm_config
()
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
):
pooling_metadata
:
PoolingMetadata
):
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
[
p
.
to
(
self
.
head_dtype
)
for
p
in
pooled_data
]
else
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
pooling_params
=
get_pooling_params
(
pooling_metadata
)
pooling_params
=
get_pooling_params
(
pooling_metadata
)
# for softmax
# for softmax
...
@@ -641,6 +646,7 @@ class ClassifierPooler(Pooler):
...
@@ -641,6 +646,7 @@ class ClassifierPooler(Pooler):
self
.
act_fn
=
act_fn
or
PoolerClassify
()
self
.
act_fn
=
act_fn
or
PoolerClassify
()
self
.
logit_bias
:
Optional
[
self
.
logit_bias
:
Optional
[
float
]
=
vllm_config
.
model_config
.
pooler_config
.
logit_bias
float
]
=
vllm_config
.
model_config
.
pooler_config
.
logit_bias
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"classify"
,
"score"
}
return
{
"classify"
,
"score"
}
...
@@ -655,6 +661,8 @@ class ClassifierPooler(Pooler):
...
@@ -655,6 +661,8 @@ class ClassifierPooler(Pooler):
pooled_data
=
torch
.
stack
(
pooled_data
)
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_size]
# pooled_data shape: [batchsize, hidden_size]
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
if
self
.
classifier
is
not
None
:
if
self
.
classifier
is
not
None
:
pooled_data
=
self
.
classifier
(
pooled_data
)
pooled_data
=
self
.
classifier
(
pooled_data
)
# pooled_data shape: [batchsize, num_labels]
# pooled_data shape: [batchsize, num_labels]
...
...
vllm/model_executor/models/adapters.py
View file @
19332c04
...
@@ -62,7 +62,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
...
@@ -62,7 +62,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
linear
=
nn
.
Linear
(
layer_config
.
get
(
"in_features"
,
768
),
linear
=
nn
.
Linear
(
layer_config
.
get
(
"in_features"
,
768
),
layer_config
.
get
(
"out_features"
,
768
),
layer_config
.
get
(
"out_features"
,
768
),
bias
=
layer_config
.
get
(
"bias"
,
True
),
bias
=
layer_config
.
get
(
"bias"
,
True
),
dtype
=
torch
.
float32
)
dtype
=
model_config
.
head_dtype
)
if
not
_load_dense_weights
(
linear
,
folder
,
model_config
):
if
not
_load_dense_weights
(
linear
,
folder
,
model_config
):
continue
continue
...
@@ -70,7 +70,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
...
@@ -70,7 +70,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
layers
.
append
(
linear
)
layers
.
append
(
linear
)
if
act_name
:
=
layer_config
.
get
(
"activation_function"
):
if
act_name
:
=
layer_config
.
get
(
"activation_function"
):
layers
.
append
(
get_act_fn
(
act_name
))
layers
.
append
(
get_act_fn
(
act_name
))
return
nn
.
Sequential
(
*
layers
).
to
(
dtype
=
torch
.
float32
)
return
nn
.
Sequential
(
*
layers
).
to
(
dtype
=
model_config
.
head_dtype
)
except
Exception
:
except
Exception
:
logger
.
exception
(
"ST projector loading failed"
)
logger
.
exception
(
"ST projector loading failed"
)
...
@@ -105,15 +105,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str,
...
@@ -105,15 +105,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str,
if
weight_key
in
state_dict
:
if
weight_key
in
state_dict
:
weight_loader
=
getattr
(
linear
.
weight
,
"weight_loader"
,
weight_loader
=
getattr
(
linear
.
weight
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
linear
.
weight
,
weight_loader
(
linear
.
weight
,
state_dict
[
weight_key
])
state_dict
[
weight_key
].
to
(
torch
.
float32
))
bias_key
=
weight_key
.
replace
(
"weight"
,
"bias"
)
bias_key
=
weight_key
.
replace
(
"weight"
,
"bias"
)
if
linear
.
bias
is
not
None
and
bias_key
in
state_dict
:
if
linear
.
bias
is
not
None
and
bias_key
in
state_dict
:
bias_loader
=
getattr
(
linear
.
bias
,
"weight_loader"
,
bias_loader
=
getattr
(
linear
.
bias
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
bias_loader
(
linear
.
bias
,
bias_loader
(
linear
.
bias
,
state_dict
[
bias_key
])
state_dict
[
bias_key
].
to
(
torch
.
float32
))
return
True
return
True
except
Exception
:
except
Exception
:
logger
.
exception
(
"Failed to load %s"
,
filename
)
logger
.
exception
(
"Failed to load %s"
,
filename
)
...
...
vllm/model_executor/models/bert.py
View file @
19332c04
...
@@ -562,7 +562,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
...
@@ -562,7 +562,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
self
.
bert
=
BertPoolingModel
(
vllm_config
=
vllm_config
,
self
.
bert
=
BertPoolingModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"bert"
),
prefix
=
maybe_prefix
(
prefix
,
"bert"
),
embedding_class
=
BertEmbedding
)
embedding_class
=
BertEmbedding
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
,
dtype
=
vllm_config
.
model_config
.
head_dtype
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
...
...
vllm/model_executor/models/bert_with_rope.py
View file @
19332c04
...
@@ -637,13 +637,13 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -637,13 +637,13 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
.
new
=
GteNewModel
(
vllm_config
=
vllm_config
,
self
.
new
=
GteNewModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
prefix
=
prefix
,
add_pooling_layer
=
True
)
add_pooling_layer
=
True
)
self
.
classifier
=
RowParallelLinear
(
config
.
hidden_size
,
self
.
classifier
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_labels
,
config
.
num_labels
,
input_is_parallel
=
False
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
params_dtype
=
vllm_config
.
model_config
.
head_dtype
,
prefix
,
"classifier"
),
prefix
=
maybe_prefix
(
prefix
,
"classifier"
),
return_bias
=
False
)
return_bias
=
False
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
...
...
vllm/model_executor/models/gpt2.py
View file @
19332c04
...
@@ -339,7 +339,10 @@ class GPT2ForSequenceClassification(nn.Module):
...
@@ -339,7 +339,10 @@ class GPT2ForSequenceClassification(nn.Module):
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
self
.
transformer
=
GPT2Model
(
vllm_config
=
vllm_config
,
self
.
transformer
=
GPT2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"gpt2"
))
prefix
=
maybe_prefix
(
prefix
,
"gpt2"
))
self
.
score
=
nn
.
Linear
(
config
.
n_embd
,
config
.
num_labels
,
bias
=
False
)
self
.
score
=
nn
.
Linear
(
config
.
n_embd
,
config
.
num_labels
,
bias
=
False
,
dtype
=
vllm_config
.
model_config
.
head_dtype
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
...
@@ -348,7 +351,7 @@ class GPT2ForSequenceClassification(nn.Module):
...
@@ -348,7 +351,7 @@ class GPT2ForSequenceClassification(nn.Module):
"encode"
:
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
Pooler
.
for_encode
(
pooler_config
),
"classify"
:
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
Non
e
),
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
scor
e
),
})
})
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
...
@@ -367,8 +370,7 @@ class GPT2ForSequenceClassification(nn.Module):
...
@@ -367,8 +370,7 @@ class GPT2ForSequenceClassification(nn.Module):
position_ids
=
positions
,
position_ids
=
positions
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
)
intermediate_tensors
=
intermediate_tensors
)
logits
=
self
.
score
(
hidden_states
)
return
hidden_states
return
logits
def
_add_transformer_prefix
(
def
_add_transformer_prefix
(
...
...
vllm/model_executor/models/internlm2.py
View file @
19332c04
...
@@ -423,13 +423,15 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
...
@@ -423,13 +423,15 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
delattr
(
self
,
attr
)
delattr
(
self
,
attr
)
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
self
.
v_head
=
RowParallelLinear
(
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
config
.
hidden_size
,
self
.
v_head
=
RowParallelLinear
(
config
.
hidden_size
,
1
,
1
,
bias
=
False
,
bias
=
False
,
input_is_parallel
=
False
,
input_is_parallel
=
False
,
params_dtype
=
self
.
head_dtype
,
prefix
=
maybe_prefix
(
prefix
,
"v_head"
),
prefix
=
maybe_prefix
(
prefix
,
"v_head"
),
)
return_bias
=
False
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
...
@@ -446,5 +448,6 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
...
@@ -446,5 +448,6 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
logits
,
_
=
self
.
v_head
(
hidden_states
)
hidden_states
=
hidden_states
.
to
(
self
.
head_dtype
)
logits
=
self
.
v_head
(
hidden_states
)
return
logits
return
logits
vllm/model_executor/models/jamba.py
View file @
19332c04
...
@@ -613,7 +613,7 @@ class JambaForSequenceClassification(JambaForCausalLM):
...
@@ -613,7 +613,7 @@ class JambaForSequenceClassification(JambaForCausalLM):
config
.
hidden_size
,
config
.
hidden_size
,
num_labels
,
num_labels
,
bias
=
score_bias
,
bias
=
score_bias
,
dtype
=
torch
.
float32
,
dtype
=
vllm_config
.
model_config
.
head_dtype
,
)
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
...
...
vllm/model_executor/models/jina_vl.py
View file @
19332c04
...
@@ -5,9 +5,9 @@ from typing import Optional
...
@@ -5,9 +5,9 @@ from typing import Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers
import
BatchFeature
,
PretrainedConfig
from
transformers
import
BatchFeature
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.inputs
import
TokensPrompt
from
vllm.inputs
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -28,13 +28,17 @@ logger = init_logger(__name__)
...
@@ -28,13 +28,17 @@ logger = init_logger(__name__)
class
JinaVLScorer
(
nn
.
Module
):
class
JinaVLScorer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Pretrained
Config
):
def
__init__
(
self
,
model_
config
:
"Model
Config
"
):
super
().
__init__
()
super
().
__init__
()
config
=
model_config
.
hf_config
head_dtype
=
model_config
.
head_dtype
self
.
dense
=
ColumnParallelLinear
(
config
.
hidden_size
,
self
.
dense
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
hidden_size
,
params_dtype
=
head_dtype
,
bias
=
True
)
bias
=
True
)
self
.
out_proj
=
RowParallelLinear
(
config
.
hidden_size
,
self
.
out_proj
=
RowParallelLinear
(
config
.
hidden_size
,
config
.
num_labels
,
config
.
num_labels
,
params_dtype
=
head_dtype
,
bias
=
True
)
bias
=
True
)
def
forward
(
self
,
x
,
**
kwargs
):
def
forward
(
self
,
x
,
**
kwargs
):
...
@@ -88,11 +92,10 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
...
@@ -88,11 +92,10 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"qwen2_vl"
))
prefix
=
maybe_prefix
(
prefix
,
"qwen2_vl"
))
config
=
vllm_config
.
model_config
.
hf_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
score
=
JinaVLScorer
(
config
)
self
.
score
=
JinaVLScorer
(
vllm_config
.
model_
config
)
self
.
pooler
=
DispatchPooler
({
self
.
pooler
=
DispatchPooler
({
"encode"
:
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
Pooler
.
for_encode
(
pooler_config
),
...
...
vllm/model_executor/models/modernbert.py
View file @
19332c04
...
@@ -306,7 +306,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -306,7 +306,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
.
config
=
config
self
.
config
=
config
self
.
model
=
ModernBertModel
(
vllm_config
=
vllm_config
,
self
.
model
=
ModernBertModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"modernbert"
))
prefix
=
maybe_prefix
(
prefix
,
"modernbert"
))
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
,
dtype
=
vllm_config
.
model_config
.
head_dtype
)
self
.
pooling
=
ModernBertPooler
(
config
)
self
.
pooling
=
ModernBertPooler
(
config
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
...
...
vllm/model_executor/models/qwen2_rm.py
View file @
19332c04
...
@@ -53,15 +53,18 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -53,15 +53,18 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
model
=
Qwen2Model
(
vllm_config
=
vllm_config
,
self
.
model
=
Qwen2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
self
.
score
=
nn
.
Sequential
(
self
.
score
=
nn
.
Sequential
(
ColumnParallelLinear
(
config
.
hidden_size
,
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
params_dtype
=
self
.
head_dtype
,
return_bias
=
False
),
return_bias
=
False
),
nn
.
ReLU
(),
nn
.
ReLU
(),
RowParallelLinear
(
config
.
hidden_size
,
RowParallelLinear
(
config
.
hidden_size
,
config
.
num_labels
,
config
.
num_labels
,
params_dtype
=
self
.
head_dtype
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
return_bias
=
False
),
return_bias
=
False
),
)
)
...
@@ -80,6 +83,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -80,6 +83,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
hidden_states
=
hidden_states
.
to
(
self
.
head_dtype
)
logits
=
self
.
score
(
hidden_states
)
logits
=
self
.
score
(
hidden_states
)
return
logits
return
logits
...
...
vllm/model_executor/models/roberta.py
View file @
19332c04
...
@@ -8,7 +8,7 @@ import torch
...
@@ -8,7 +8,7 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
RobertaConfig
from
transformers
import
RobertaConfig
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
CLSPool
,
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
CLSPool
,
DispatchPooler
,
Pooler
)
DispatchPooler
,
Pooler
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -73,10 +73,16 @@ class RobertaEmbedding(nn.Module):
...
@@ -73,10 +73,16 @@ class RobertaEmbedding(nn.Module):
class
RobertaClassificationHead
(
nn
.
Module
):
class
RobertaClassificationHead
(
nn
.
Module
):
"""Head for sentence-level classification tasks."""
"""Head for sentence-level classification tasks."""
def
__init__
(
self
,
config
:
Roberta
Config
):
def
__init__
(
self
,
model_
config
:
"Model
Config
"
):
super
().
__init__
()
super
().
__init__
()
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
config
=
model_config
.
hf_config
self
.
out_proj
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
head_dtype
=
model_config
.
head_dtype
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
,
dtype
=
head_dtype
)
self
.
out_proj
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
,
dtype
=
head_dtype
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# CLSPool has already been applied in `pooling`
# CLSPool has already been applied in `pooling`
...
@@ -184,7 +190,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -184,7 +190,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
.
roberta
=
BertModel
(
vllm_config
=
vllm_config
,
self
.
roberta
=
BertModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"bert"
),
prefix
=
maybe_prefix
(
prefix
,
"bert"
),
embedding_class
=
RobertaEmbedding
)
embedding_class
=
RobertaEmbedding
)
self
.
classifier
=
RobertaClassificationHead
(
config
)
self
.
classifier
=
RobertaClassificationHead
(
vllm_config
.
model_
config
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment