Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
19332c04
Unverified
Commit
19332c04
authored
Sep 09, 2025
by
wang.yuqi
Committed by
GitHub
Sep 09, 2025
Browse files
[Model] Systematic support for fp32 head, pooling models part (#23810)
Signed-off-by:
wang.yuqi
<
noooop@126.com
>
parent
a55cf41a
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
166 additions
and
61 deletions
+166
-61
tests/models/language/pooling/mteb_utils.py
tests/models/language/pooling/mteb_utils.py
+31
-6
tests/models/language/pooling/test_bge_reranker_v2_gemma.py
tests/models/language/pooling/test_bge_reranker_v2_gemma.py
+1
-0
vllm/config/__init__.py
vllm/config/__init__.py
+52
-1
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+23
-15
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+4
-6
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+3
-1
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+8
-8
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+6
-4
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+11
-8
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+1
-1
vllm/model_executor/models/jina_vl.py
vllm/model_executor/models/jina_vl.py
+8
-5
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+3
-1
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+4
-0
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+11
-5
No files found.
tests/models/language/pooling/mteb_utils.py
View file @
19332c04
...
...
@@ -9,6 +9,7 @@ import mteb
import
numpy
as
np
import
pytest
import
requests
import
torch
from
tests.models.utils
import
(
EmbedModelInfo
,
RerankModelInfo
,
check_embeddings_close
)
...
...
@@ -165,16 +166,19 @@ def mteb_test_embed_models(hf_runner,
vllm_extra_kwargs
=
None
,
hf_model_callback
=
None
,
atol
=
MTEB_EMBED_TOL
):
# A model family has many models with the same architecture,
# and we don't need to test each one.
if
not
model_info
.
enable_test
:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest
.
skip
(
"Skipping test."
)
example_prompts
=
[
"The chef prepared a delicious meal."
]
# Test embed_dims, isnan and whether to use normalize
example_prompts
=
[
"The chef prepared a delicious meal."
*
1000
]
# Allow vllm to test using the given dtype, such as float32
vllm_extra_kwargs
=
vllm_extra_kwargs
or
{}
vllm_extra_kwargs
[
"dtype"
]
=
model_info
.
dtype
# Allow vllm to test using hf_overrides
if
model_info
.
hf_overrides
is
not
None
:
vllm_extra_kwargs
[
"hf_overrides"
]
=
model_info
.
hf_overrides
...
...
@@ -186,21 +190,32 @@ def mteb_test_embed_models(hf_runner,
model_config
=
vllm_model
.
llm
.
llm_engine
.
model_config
# Confirm whether vllm is using the correct architecture
if
model_info
.
architecture
:
assert
model_info
.
architecture
in
model_config
.
architectures
# Confirm whether vllm uses the correct default_pooling_type, which
# relates to whether chunked prefill and prefix caching are enabled
assert
(
model_config
.
_model_info
.
default_pooling_type
==
model_info
.
default_pooling_type
)
vllm_main_score
=
run_mteb_embed_task
(
VllmMtebEncoder
(
vllm_model
),
MTEB_EMBED_TASKS
)
vllm_dtype
=
vllm_model
.
llm
.
llm_engine
.
model_config
.
dtype
vllm_outputs
=
vllm_model
.
embed
(
example_prompts
)
# Test embed_dims, isnan and whether to use normalize
vllm_outputs
=
vllm_model
.
embed
(
example_prompts
,
truncate_prompt_tokens
=-
1
)
assert
not
torch
.
any
(
torch
.
isnan
(
torch
.
tensor
(
vllm_outputs
)))
# Accelerate mteb test by setting
# SentenceTransformers mteb score to a constant
if
model_info
.
mteb_score
is
None
:
with
hf_runner
(
model_info
.
name
,
is_sentence_transformer
=
True
,
dtype
=
"float32"
)
as
hf_model
:
# e.g. setting default parameters for the encode method of hf_runner
if
hf_model_callback
is
not
None
:
hf_model_callback
(
hf_model
)
...
...
@@ -299,14 +314,16 @@ def mteb_test_rerank_models(hf_runner,
hf_model_callback
=
None
,
vllm_mteb_encoder
=
VllmMtebEncoder
,
atol
=
MTEB_RERANK_TOL
):
# A model family has many models with the same architecture,
# and we don't need to test each one.
if
not
model_info
.
enable_test
:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest
.
skip
(
"Skipping test."
)
# Allow vllm to test using the given dtype, such as float32
vllm_extra_kwargs
=
vllm_extra_kwargs
or
{}
vllm_extra_kwargs
[
"dtype"
]
=
model_info
.
dtype
# Allow vllm to test using hf_overrides
if
model_info
.
hf_overrides
is
not
None
:
vllm_extra_kwargs
[
"hf_overrides"
]
=
model_info
.
hf_overrides
...
...
@@ -319,9 +336,15 @@ def mteb_test_rerank_models(hf_runner,
model_config
=
vllm_model
.
llm
.
llm_engine
.
model_config
# Confirm whether vllm is using the correct architecture
if
model_info
.
architecture
:
assert
(
model_info
.
architecture
in
model_config
.
architectures
)
# Score API is only enabled for num_labels == 1
assert
model_config
.
hf_config
.
num_labels
==
1
# Confirm whether vllm uses the correct default_pooling_type, which
# relates to whether chunked prefill and prefix caching are enabled
assert
(
model_config
.
_model_info
.
default_pooling_type
==
model_info
.
default_pooling_type
)
...
...
@@ -330,6 +353,8 @@ def mteb_test_rerank_models(hf_runner,
languages
=
MTEB_RERANK_LANGS
)
vllm_dtype
=
model_config
.
dtype
# Accelerate mteb test by setting
# SentenceTransformers mteb score to a constant
if
model_info
.
mteb_score
is
None
:
st_main_score
,
st_dtype
=
mteb_test_rerank_models_hf
(
hf_runner
,
model_info
.
name
,
hf_model_callback
)
...
...
tests/models/language/pooling/test_bge_reranker_v2_gemma.py
View file @
19332c04
...
...
@@ -14,6 +14,7 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
RERANK_MODELS
=
[
LASTPoolingRerankModelInfo
(
"BAAI/bge-reranker-v2-gemma"
,
architecture
=
"GemmaForSequenceClassification"
,
mteb_score
=
0.33757
,
hf_overrides
=
{
"architectures"
:
[
"GemmaForSequenceClassification"
],
...
...
vllm/config/__init__.py
View file @
19332c04
...
...
@@ -745,7 +745,7 @@ class ModelConfig:
self
.
pooler_config
=
self
.
_init_pooler_config
()
self
.
dtype
=
_get_and_verify_dtype
(
self
.
dtype
:
torch
.
dtype
=
_get_and_verify_dtype
(
self
.
model
,
self
.
hf_config
,
self
.
dtype
,
...
...
@@ -1751,6 +1751,32 @@ class ModelConfig:
# `llm as reranker` models defaults to not using pad_token.
return
getattr
(
self
.
hf_config
,
"use_pad_token"
,
True
)
@
property
def
head_dtype
(
self
)
->
torch
.
dtype
:
"""
"head" refers to the last Linear layer(s) of an LLM,
such as the lm_head in a generation model,
or the score or classifier in a classification model.
The default head_dtype based on runner_type.
\n
- The pooling model defaults to using fp32 head,
you can use --hf-overrides '{"head_dtype": "model"}' to disable it.
\n
- The generate model defaults to not using fp32 head,
you can use --hf-overrides '{"head_dtype": "float32"}' to enable it.
"""
head_dtype
=
_get_head_dtype
(
config
=
self
.
hf_config
,
dtype
=
self
.
dtype
,
runner_type
=
self
.
runner_type
)
if
head_dtype
not
in
current_platform
.
supported_dtypes
:
logger
.
warning_once
(
"The current platform does not support [%s] head dtype, "
"fallback to model dtype [%s]."
,
head_dtype
,
self
.
dtype
)
return
self
.
dtype
logger
.
debug_once
(
"head dtype: %s"
,
head_dtype
)
return
head_dtype
def
get_and_verify_max_len
(
self
,
max_model_len
:
int
):
# Consider max_model_len in tokenizer_config only when
# pooling models use absolute position_embedding.
...
...
@@ -2893,6 +2919,31 @@ def _get_and_verify_dtype(
return
torch_dtype
def
_get_head_dtype
(
config
:
PretrainedConfig
,
dtype
:
torch
.
dtype
,
runner_type
:
str
)
->
torch
.
dtype
:
head_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
getattr
(
config
,
"head_dtype"
,
None
)
if
head_dtype
==
"model"
:
return
dtype
elif
isinstance
(
head_dtype
,
str
):
head_dtype
=
head_dtype
.
lower
()
if
head_dtype
not
in
_STR_DTYPE_TO_TORCH_DTYPE
:
raise
ValueError
(
f
"Unknown dtype:
{
head_dtype
!
r
}
"
)
return
_STR_DTYPE_TO_TORCH_DTYPE
[
head_dtype
]
elif
isinstance
(
head_dtype
,
torch
.
dtype
):
return
head_dtype
elif
head_dtype
is
None
:
if
torch
.
float32
not
in
current_platform
.
supported_dtypes
:
return
dtype
if
runner_type
==
"pooling"
:
return
torch
.
float32
return
dtype
else
:
raise
ValueError
(
f
"Unknown dtype:
{
head_dtype
}
"
)
def
_get_and_verify_max_len
(
hf_config
:
PretrainedConfig
,
tokenizer_config
:
Optional
[
dict
],
...
...
vllm/model_executor/layers/pooler.py
View file @
19332c04
...
...
@@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
from
dataclasses
import
dataclass
from
enum
import
IntEnum
from
itertools
import
groupby
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
,
cast
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
import
torch
import
torch.nn
as
nn
...
...
@@ -362,14 +362,13 @@ class PoolerIdentity(PoolerActivation):
class
PoolerNormalize
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
F
.
normalize
(
pooled_data
.
float
(),
p
=
2
,
dim
=-
1
)
return
x
.
to
(
pooled_data
.
dtype
)
return
F
.
normalize
(
pooled_data
,
p
=
2
,
dim
=-
1
)
class
PoolerMultiLabelClassify
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
F
.
sigmoid
(
pooled_data
.
float
()).
to
(
pooled_data
.
dtype
)
return
F
.
sigmoid
(
pooled_data
)
class
PoolerClassify
(
PoolerActivation
):
...
...
@@ -394,9 +393,9 @@ class PoolerClassify(PoolerActivation):
pooled_data
.
shape
[
-
1
])
if
num_labels
<
2
:
return
F
.
sigmoid
(
pooled_data
.
float
()).
to
(
pooled_data
.
dtype
)
return
F
.
sigmoid
(
pooled_data
)
return
F
.
softmax
(
pooled_data
.
float
()
,
dim
=-
1
)
.
to
(
pooled_data
.
dtype
)
return
F
.
softmax
(
pooled_data
,
dim
=-
1
)
class
LambdaPoolerActivation
(
PoolerActivation
):
...
...
@@ -432,8 +431,9 @@ class EmbeddingPoolerHead(PoolerHead):
from
vllm.model_executor.models.adapters
import
_load_st_projector
vllm_config
=
get_current_vllm_config
()
self
.
projector
=
_load_st_projector
(
self
.
projector
:
Optional
[
nn
.
Module
]
=
_load_st_projector
(
vllm_config
.
model_config
)
if
vllm_config
else
None
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
):
...
...
@@ -442,16 +442,11 @@ class EmbeddingPoolerHead(PoolerHead):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_dimension]
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# Apply ST projector
if
self
.
projector
is
not
None
:
projector
=
cast
(
nn
.
Module
,
self
.
projector
)
def
_proj
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
orig_dtype
=
x
.
dtype
y
=
projector
(
x
.
to
(
torch
.
float32
))
return
y
.
to
(
orig_dtype
)
pooled_data
=
_proj
(
pooled_data
)
pooled_data
=
self
.
projector
(
pooled_data
)
# pooled_data shape: [batchsize, embedding_dimension]
pooling_params
=
get_pooling_params
(
pooling_metadata
)
...
...
@@ -494,8 +489,18 @@ class RewardPoolerHead(PoolerHead):
def
__init__
(
self
)
->
None
:
super
().
__init__
(
activation
=
PoolerClassify
(
static_num_labels
=
False
))
from
vllm.config
import
get_current_vllm_config
vllm_config
=
get_current_vllm_config
()
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
):
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
[
p
.
to
(
self
.
head_dtype
)
for
p
in
pooled_data
]
else
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
pooling_params
=
get_pooling_params
(
pooling_metadata
)
# for softmax
...
...
@@ -641,6 +646,7 @@ class ClassifierPooler(Pooler):
self
.
act_fn
=
act_fn
or
PoolerClassify
()
self
.
logit_bias
:
Optional
[
float
]
=
vllm_config
.
model_config
.
pooler_config
.
logit_bias
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"classify"
,
"score"
}
...
...
@@ -655,6 +661,8 @@ class ClassifierPooler(Pooler):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_size]
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
if
self
.
classifier
is
not
None
:
pooled_data
=
self
.
classifier
(
pooled_data
)
# pooled_data shape: [batchsize, num_labels]
...
...
vllm/model_executor/models/adapters.py
View file @
19332c04
...
...
@@ -62,7 +62,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
linear
=
nn
.
Linear
(
layer_config
.
get
(
"in_features"
,
768
),
layer_config
.
get
(
"out_features"
,
768
),
bias
=
layer_config
.
get
(
"bias"
,
True
),
dtype
=
torch
.
float32
)
dtype
=
model_config
.
head_dtype
)
if
not
_load_dense_weights
(
linear
,
folder
,
model_config
):
continue
...
...
@@ -70,7 +70,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
layers
.
append
(
linear
)
if
act_name
:
=
layer_config
.
get
(
"activation_function"
):
layers
.
append
(
get_act_fn
(
act_name
))
return
nn
.
Sequential
(
*
layers
).
to
(
dtype
=
torch
.
float32
)
return
nn
.
Sequential
(
*
layers
).
to
(
dtype
=
model_config
.
head_dtype
)
except
Exception
:
logger
.
exception
(
"ST projector loading failed"
)
...
...
@@ -105,15 +105,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str,
if
weight_key
in
state_dict
:
weight_loader
=
getattr
(
linear
.
weight
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
linear
.
weight
,
state_dict
[
weight_key
].
to
(
torch
.
float32
))
weight_loader
(
linear
.
weight
,
state_dict
[
weight_key
])
bias_key
=
weight_key
.
replace
(
"weight"
,
"bias"
)
if
linear
.
bias
is
not
None
and
bias_key
in
state_dict
:
bias_loader
=
getattr
(
linear
.
bias
,
"weight_loader"
,
default_weight_loader
)
bias_loader
(
linear
.
bias
,
state_dict
[
bias_key
].
to
(
torch
.
float32
))
bias_loader
(
linear
.
bias
,
state_dict
[
bias_key
])
return
True
except
Exception
:
logger
.
exception
(
"Failed to load %s"
,
filename
)
...
...
vllm/model_executor/models/bert.py
View file @
19332c04
...
...
@@ -562,7 +562,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
self
.
bert
=
BertPoolingModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"bert"
),
embedding_class
=
BertEmbedding
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
,
dtype
=
vllm_config
.
model_config
.
head_dtype
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
...
...
vllm/model_executor/models/bert_with_rope.py
View file @
19332c04
...
...
@@ -637,14 +637,14 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
.
new
=
GteNewModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
add_pooling_layer
=
True
)
self
.
classifier
=
R
owParallelLinear
(
config
.
hidden_size
,
config
.
num_labels
,
input_is_parallel
=
False
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"classifier"
),
return_bias
=
False
)
self
.
classifier
=
R
eplicatedLinear
(
config
.
hidden_size
,
config
.
num_labels
,
bias
=
True
,
quant_config
=
quant_config
,
params_dtype
=
vllm_config
.
model_config
.
head_dtype
,
prefix
=
maybe_prefix
(
prefix
,
"classifier"
),
return_bias
=
False
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
...
...
vllm/model_executor/models/gpt2.py
View file @
19332c04
...
...
@@ -339,7 +339,10 @@ class GPT2ForSequenceClassification(nn.Module):
config
=
vllm_config
.
model_config
.
hf_config
self
.
transformer
=
GPT2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"gpt2"
))
self
.
score
=
nn
.
Linear
(
config
.
n_embd
,
config
.
num_labels
,
bias
=
False
)
self
.
score
=
nn
.
Linear
(
config
.
n_embd
,
config
.
num_labels
,
bias
=
False
,
dtype
=
vllm_config
.
model_config
.
head_dtype
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
...
...
@@ -348,7 +351,7 @@ class GPT2ForSequenceClassification(nn.Module):
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
Non
e
),
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
scor
e
),
})
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
...
...
@@ -367,8 +370,7 @@ class GPT2ForSequenceClassification(nn.Module):
position_ids
=
positions
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
)
logits
=
self
.
score
(
hidden_states
)
return
logits
return
hidden_states
def
_add_transformer_prefix
(
...
...
vllm/model_executor/models/internlm2.py
View file @
19332c04
...
...
@@ -423,13 +423,15 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
delattr
(
self
,
attr
)
config
=
vllm_config
.
model_config
.
hf_config
self
.
v_head
=
RowParallelLinear
(
config
.
hidden_size
,
1
,
bias
=
False
,
input_is_parallel
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"v_head"
),
)
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
self
.
v_head
=
RowParallelLinear
(
config
.
hidden_size
,
1
,
bias
=
False
,
input_is_parallel
=
False
,
params_dtype
=
self
.
head_dtype
,
prefix
=
maybe_prefix
(
prefix
,
"v_head"
),
return_bias
=
False
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
...
...
@@ -446,5 +448,6 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
logits
,
_
=
self
.
v_head
(
hidden_states
)
hidden_states
=
hidden_states
.
to
(
self
.
head_dtype
)
logits
=
self
.
v_head
(
hidden_states
)
return
logits
vllm/model_executor/models/jamba.py
View file @
19332c04
...
...
@@ -613,7 +613,7 @@ class JambaForSequenceClassification(JambaForCausalLM):
config
.
hidden_size
,
num_labels
,
bias
=
score_bias
,
dtype
=
torch
.
float32
,
dtype
=
vllm_config
.
model_config
.
head_dtype
,
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
...
...
vllm/model_executor/models/jina_vl.py
View file @
19332c04
...
...
@@ -5,9 +5,9 @@ from typing import Optional
import
torch
import
torch.nn
as
nn
from
transformers
import
BatchFeature
,
PretrainedConfig
from
transformers
import
BatchFeature
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.inputs
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -28,13 +28,17 @@ logger = init_logger(__name__)
class
JinaVLScorer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Pretrained
Config
):
def
__init__
(
self
,
model_
config
:
"Model
Config
"
):
super
().
__init__
()
config
=
model_config
.
hf_config
head_dtype
=
model_config
.
head_dtype
self
.
dense
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
params_dtype
=
head_dtype
,
bias
=
True
)
self
.
out_proj
=
RowParallelLinear
(
config
.
hidden_size
,
config
.
num_labels
,
params_dtype
=
head_dtype
,
bias
=
True
)
def
forward
(
self
,
x
,
**
kwargs
):
...
...
@@ -88,11 +92,10 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"qwen2_vl"
))
config
=
vllm_config
.
model_config
.
hf_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
score
=
JinaVLScorer
(
config
)
self
.
score
=
JinaVLScorer
(
vllm_config
.
model_
config
)
self
.
pooler
=
DispatchPooler
({
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
...
...
vllm/model_executor/models/modernbert.py
View file @
19332c04
...
...
@@ -306,7 +306,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
.
config
=
config
self
.
model
=
ModernBertModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"modernbert"
))
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
,
dtype
=
vllm_config
.
model_config
.
head_dtype
)
self
.
pooling
=
ModernBertPooler
(
config
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
...
...
vllm/model_executor/models/qwen2_rm.py
View file @
19332c04
...
...
@@ -53,15 +53,18 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
self
.
quant_config
=
quant_config
self
.
model
=
Qwen2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
self
.
score
=
nn
.
Sequential
(
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
params_dtype
=
self
.
head_dtype
,
return_bias
=
False
),
nn
.
ReLU
(),
RowParallelLinear
(
config
.
hidden_size
,
config
.
num_labels
,
params_dtype
=
self
.
head_dtype
,
quant_config
=
quant_config
,
return_bias
=
False
),
)
...
...
@@ -80,6 +83,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
hidden_states
.
to
(
self
.
head_dtype
)
logits
=
self
.
score
(
hidden_states
)
return
logits
...
...
vllm/model_executor/models/roberta.py
View file @
19332c04
...
...
@@ -8,7 +8,7 @@ import torch
from
torch
import
nn
from
transformers
import
RobertaConfig
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
CLSPool
,
DispatchPooler
,
Pooler
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -73,10 +73,16 @@ class RobertaEmbedding(nn.Module):
class
RobertaClassificationHead
(
nn
.
Module
):
"""Head for sentence-level classification tasks."""
def
__init__
(
self
,
config
:
Roberta
Config
):
def
__init__
(
self
,
model_
config
:
"Model
Config
"
):
super
().
__init__
()
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
out_proj
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
config
=
model_config
.
hf_config
head_dtype
=
model_config
.
head_dtype
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
,
dtype
=
head_dtype
)
self
.
out_proj
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
,
dtype
=
head_dtype
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# CLSPool has already been applied in `pooling`
...
...
@@ -184,7 +190,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
.
roberta
=
BertModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"bert"
),
embedding_class
=
RobertaEmbedding
)
self
.
classifier
=
RobertaClassificationHead
(
config
)
self
.
classifier
=
RobertaClassificationHead
(
vllm_config
.
model_
config
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment