Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2e26f915
Unverified
Commit
2e26f915
authored
Jul 04, 2025
by
wang.yuqi
Committed by
GitHub
Jul 04, 2025
Browse files
[Model][3/N] Automatic conversion of CrossEncoding model (#20168)
Signed-off-by:
wang.yuqi
<
noooop@126.com
>
parent
9e5452ee
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
234 additions
and
133 deletions
+234
-133
docs/models/supported_models.md
docs/models/supported_models.md
+15
-6
tests/models/language/pooling/test_embedding.py
tests/models/language/pooling/test_embedding.py
+9
-1
tests/models/language/pooling/test_gte.py
tests/models/language/pooling/test_gte.py
+12
-4
tests/models/language/pooling/test_mxbai_rerank.py
tests/models/language/pooling/test_mxbai_rerank.py
+84
-0
vllm/config.py
vllm/config.py
+10
-3
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+99
-3
vllm/model_executor/models/config.py
vllm/model_executor/models/config.py
+1
-1
vllm/model_executor/models/qwen3.py
vllm/model_executor/models/qwen3.py
+4
-115
No files found.
docs/models/supported_models.md
View file @
2e26f915
...
...
@@ -477,12 +477,20 @@ If your model is not in the above list, we will try to automatically convert the
Specified using
`--task score`
.
| Architecture | Models | Example HF Models |
[
V1
](
gh-issue:8779
)
|
|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|-----------------------|
|
`BertForSequenceClassification`
| BERT-based |
`cross-encoder/ms-marco-MiniLM-L-6-v2`
, etc. | |
|
`Qwen3ForSequenceClassification`
| Qwen3-based |
`tomaarsen/Qwen3-Reranker-0.6B-seq-cls`
,
`Qwen/Qwen3-Reranker-0.6B`
(see note), etc. | ✅︎ |
|
`RobertaForSequenceClassification`
| RoBERTa-based |
`cross-encoder/quora-roberta-base`
, etc. | |
|
`XLMRobertaForSequenceClassification`
| XLM-RoBERTa-based |
`BAAI/bge-reranker-v2-m3`
, etc. | |
| Architecture | Models | Example HF Models |
[
V1
](
gh-issue:8779
)
|
|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|---------------------|
|
`BertForSequenceClassification`
| BERT-based |
`cross-encoder/ms-marco-MiniLM-L-6-v2`
, etc. | |
|
`Qwen2ForSequenceClassification`
| Qwen2-based |
`mixedbread-ai/mxbai-rerank-base-v2`
(see note), etc. | ✅︎ |
|
`Qwen3ForSequenceClassification`
| Qwen3-based |
`tomaarsen/Qwen3-Reranker-0.6B-seq-cls`
,
`Qwen/Qwen3-Reranker-0.6B`
(see note), etc. | ✅︎ |
|
`RobertaForSequenceClassification`
| RoBERTa-based |
`cross-encoder/quora-roberta-base`
, etc. | |
|
`XLMRobertaForSequenceClassification`
| XLM-RoBERTa-based |
`BAAI/bge-reranker-v2-m3`
, etc. | |
!!! note
Load the official original
`mxbai-rerank-v2`
by using the following command.
```bash
vllm serve mixedbread-ai/mxbai-rerank-base-v2 --hf_overrides '{"architectures": ["Qwen2ForSequenceClassification"],"classifier_from_token": ["0", "1"], "method": "from_2_way_softmax"}'
```
!!! note
Load the official original
`Qwen3 Reranker`
by using the following command. More information can be found at:
<gh-file:examples
/
offline_inference
/
qwen3_reranker.py
>
.
...
...
@@ -490,6 +498,7 @@ Specified using `--task score`.
```bash
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
```
[](
){
#supported-mm-models }
## List of Multimodal Language Models
...
...
tests/models/language/pooling/test_embedding.py
View file @
2e26f915
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
typing
import
Optional
import
pytest
...
...
@@ -74,6 +75,13 @@ def test_models(
vllm_extra_kwargs
[
"override_pooler_config"
]
=
\
PoolerConfig
(
pooling_type
=
"MEAN"
,
normalize
=
False
)
max_model_len
:
Optional
[
int
]
=
512
if
model
in
[
"sentence-transformers/all-MiniLM-L12-v2"
,
"sentence-transformers/stsb-roberta-base-v2"
]:
max_model_len
=
None
# The example_prompts has ending "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers will strip the input texts, see:
...
...
@@ -87,7 +95,7 @@ def test_models(
with
vllm_runner
(
model
,
task
=
"embed"
,
max_model_len
=
512
,
max_model_len
=
max_model_len
,
**
vllm_extra_kwargs
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
embed
(
example_prompts
)
...
...
tests/models/language/pooling/test_gte.py
View file @
2e26f915
...
...
@@ -56,10 +56,16 @@ MODELS = [
enable_test
=
False
),
]
V1FlashAttentionImpNotSupported
=
[
"Alibaba-NLP/gte-Qwen2-1.5B-instruct"
,
"Alibaba-NLP/gte-modernbert-base"
]
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
def
test_embed_models_mteb
(
hf_runner
,
vllm_runner
,
model_info
:
EmbedModelInfo
)
->
None
:
def
test_embed_models_mteb
(
hf_runner
,
vllm_runner
,
model_info
:
EmbedModelInfo
,
monkeypatch
)
->
None
:
if
model_info
.
name
in
V1FlashAttentionImpNotSupported
:
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
vllm_extra_kwargs
:
dict
[
str
,
Any
]
=
{}
if
model_info
.
architecture
==
"GteNewModel"
:
...
...
@@ -71,8 +77,10 @@ def test_embed_models_mteb(hf_runner, vllm_runner,
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
def
test_embed_models_correctness
(
hf_runner
,
vllm_runner
,
model_info
:
EmbedModelInfo
,
example_prompts
)
->
None
:
model_info
:
EmbedModelInfo
,
example_prompts
,
monkeypatch
)
->
None
:
if
model_info
.
name
in
V1FlashAttentionImpNotSupported
:
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
vllm_extra_kwargs
:
dict
[
str
,
Any
]
=
{}
if
model_info
.
architecture
==
"GteNewModel"
:
...
...
tests/models/language/pooling/test_mxbai_rerank.py
0 → 100644
View file @
2e26f915
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
pytest
import
torch
from
tests.conftest
import
HfRunner
from
.mteb_utils
import
RerankModelInfo
,
mteb_test_rerank_models
# mxbai-rerank-v2 checkpoints used to parametrize the rerank test below.
# Both share the same architecture and dtype; only the name and the
# enable_test flag differ.
RERANK_MODELS = [
    RerankModelInfo(
        model_name,
        architecture="Qwen2ForSequenceClassification",
        dtype="float32",
        enable_test=enabled,
    ) for model_name, enabled in (
        ("mixedbread-ai/mxbai-rerank-base-v2", True),
        ("mixedbread-ai/mxbai-rerank-large-v2", False),
    )
]
class MxbaiRerankerHfRunner(HfRunner):
    """HF reference runner that scores prompts with an mxbai-rerank-v2 model.

    The checkpoint is loaded through ``AutoModelForCausalLM``. A prompt's
    score is the sigmoid of the last-position logit of token "1" minus the
    last-position logit of token "0".
    """

    def __init__(self,
                 model_name: str,
                 dtype: str = "auto",
                 *args: Any,
                 **kwargs: Any) -> None:
        """Load the causal LM and a left-padding tokenizer for ``model_name``.

        Extra positional and keyword arguments are accepted but not
        forwarded to the base runner.
        """
        from transformers import AutoModelForCausalLM, AutoTokenizer
        super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                       padding_side='left')
        # Vocabulary ids of the two tokens whose logits are compared.
        self.yes_loc = self.tokenizer.convert_tokens_to_ids("1")
        self.no_loc = self.tokenizer.convert_tokens_to_ids("0")

    def predict(self, prompts: list[list[str]], *args,
                **kwargs) -> torch.Tensor:
        """Return one score per prompt as a 1-D tensor.

        Each prompt is tokenized and run through the model on its own
        (batch of one). Extra arguments are ignored.
        """

        def process_inputs(pairs):
            # Tokenize without padding first, then pad and convert to
            # tensors in a separate step.
            inputs = self.tokenizer(pairs,
                                    padding=False,
                                    truncation='longest_first',
                                    return_attention_mask=False)
            inputs = self.tokenizer.pad(inputs,
                                        padding=True,
                                        return_tensors="pt")
            # Move every tensor onto the model's device.
            for key in inputs:
                inputs[key] = inputs[key].to(self.model.device)
            return inputs

        @torch.no_grad()
        def compute_logits(inputs):
            # Only the logits at the final sequence position are used.
            logits = self.model(**inputs).logits[:, -1, :]
            yes_logits = logits[:, self.yes_loc]
            no_logits = logits[:, self.no_loc]
            logits = yes_logits - no_logits
            scores = logits.float().sigmoid()
            return scores

        scores = []
        for prompt in prompts:
            inputs = process_inputs([prompt])
            score = compute_logits(inputs)
            scores.append(score[0].item())
        return torch.Tensor(scores)
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner,
                            model_info: RerankModelInfo) -> None:
    """Run the MTEB rerank comparison for one mxbai-rerank-v2 model.

    For Qwen2ForSequenceClassification entries, hf_overrides are supplied
    that select the architecture and the "from_2_way_softmax" conversion
    with classifier tokens "0" and "1".
    """
    extra_kwargs: dict[str, Any] = {}

    if model_info.architecture == "Qwen2ForSequenceClassification":
        hf_overrides = {
            "architectures": ["Qwen2ForSequenceClassification"],
            "classifier_from_token": ["0", "1"],
            "method": "from_2_way_softmax",
        }
        extra_kwargs["hf_overrides"] = hf_overrides

    mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info,
                            extra_kwargs)
vllm/config.py
View file @
2e26f915
...
...
@@ -466,6 +466,9 @@ class ModelConfig:
"affect the random state of the Python process that "
"launched vLLM."
,
self
.
seed
)
# Set served_model_name before maybe_model_redirect(self.model) reassigns self.model.
self
.
served_model_name
=
get_served_model_name
(
self
.
model
,
self
.
served_model_name
)
self
.
model
=
maybe_model_redirect
(
self
.
model
)
# The tokenizer is consistent with the model by default.
if
self
.
tokenizer
is
None
:
...
...
@@ -609,8 +612,6 @@ class ModelConfig:
self
.
original_max_model_len
=
self
.
max_model_len
self
.
max_model_len
=
self
.
get_and_verify_max_len
(
self
.
max_model_len
)
self
.
served_model_name
=
get_served_model_name
(
self
.
model
,
self
.
served_model_name
)
self
.
multimodal_config
=
self
.
_init_multimodal_config
()
if
not
self
.
skip_tokenizer_init
:
self
.
_verify_tokenizer_mode
()
...
...
@@ -1420,7 +1421,7 @@ class ModelConfig:
@
property
def
is_cross_encoder
(
self
)
->
bool
:
return
self
.
registry
.
is_cross_encoder_model
(
self
.
architectures
)
return
self
.
task
==
"classify"
@
property
def
use_mla
(
self
)
->
bool
:
...
...
@@ -4762,6 +4763,12 @@ class VllmConfig:
if
cls
is
not
None
:
cls
.
verify_and_update_config
(
self
)
if
self
.
model_config
.
task
==
"classify"
:
# Maybe convert ForCausalLM into ForSequenceClassification model.
from
vllm.model_executor.models.adapters
import
(
SequenceClassificationConfig
)
SequenceClassificationConfig
.
verify_and_update_config
(
self
)
def
__str__
(
self
):
return
(
f
"model=
{
self
.
model_config
.
model
!
r
}
,"
...
...
vllm/model_executor/models/adapters.py
View file @
2e26f915
...
...
@@ -2,14 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
TypeVar
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
TypeVar
,
Union
,
cast
import
torch
import
torch.nn
as
nn
from
vllm.model_executor.models.config
import
VerifyAndUpdateConfig
from
.interfaces_base
import
VllmModelForPooling
,
is_pooling_model
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
PoolingType
_T
=
TypeVar
(
"_T"
,
bound
=
type
[
nn
.
Module
])
...
...
@@ -39,7 +42,6 @@ def _create_pooling_model_cls(
default_softmax
:
bool
,
)
->
_T
:
# Lazy import
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolerOutput
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
...
...
@@ -162,7 +164,6 @@ def as_seq_cls_model(cls: _T) -> _T:
return
cls
# Lazy import
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.linear
import
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
PoolerOutput
,
PoolingType
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
...
...
@@ -193,6 +194,7 @@ def as_seq_cls_model(cls: _T) -> _T:
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
vllm_config
=
vllm_config
self
.
task
=
vllm_config
.
model_config
.
task
self
.
pooling_type
=
(
vllm_config
.
model_config
.
pooler_config
.
pooling_type
)
...
...
@@ -242,6 +244,17 @@ def as_seq_cls_model(cls: _T) -> _T:
]
return
PoolerOutput
(
outputs
=
pooled_outputs
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load weights, converting a ForCausalLM checkpoint when configured.

    If the config defines either ``classifier_from_token`` or ``method``,
    loading is delegated to ``seq_cls_model_loader``; otherwise the parent
    class's ``load_weights`` is used.
    """
    hf_cfg = self.config
    needs_conversion = (
        getattr(hf_cfg, "classifier_from_token", None) is not None
        or getattr(hf_cfg, "method", None) is not None)

    if needs_conversion:
        # Online convert ForCausalLM into
        # ForSequenceClassification model.
        return seq_cls_model_loader(self, weights)

    return super().load_weights(weights)
ModelForSequenceClassification
.
__name__
=
\
_get_pooling_model_name
(
cls
.
__name__
,
"ForSequenceClassification"
)
...
...
@@ -277,3 +290,86 @@ def as_reward_model(cls: _T) -> _T:
_get_pooling_model_name
(
cls
.
__name__
,
"ForReward"
)
return
ModelForReward
# type: ignore
class SequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook for checkpoints converted online into a seq-cls model."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate the conversion settings and set ``num_labels``.

        Does nothing when the HF config has no ``method``. Otherwise
        ``classifier_from_token`` must be present and ``method`` must be a
        key of ``SEQ_CLS_LOAD_METHODS``. ``num_labels`` becomes 1 for
        "from_2_way_softmax" (which requires exactly two tokens) and the
        number of tokens for any other method.
        """
        config = vllm_config.model_config.hf_config

        method = getattr(config, "method", None)
        if method is None:
            # No online conversion requested.
            return

        tokens = getattr(config, "classifier_from_token", None)
        assert tokens is not None
        assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

        if method == "from_2_way_softmax":
            assert len(tokens) == 2
            num_labels = 1
        else:
            num_labels = len(tokens)
        config.num_labels = num_labels
def load_weights_using_from_2_way_softmax(
        model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load causal-LM weights and derive the score head from two LM-head rows.

    A temporary ``lm_head`` is attached to ``model`` so the checkpoint loads
    through ``AutoWeightsLoader``. The two entries of the config's
    ``classifier_from_token`` are resolved to ids with the model's
    tokenizer, and ``score.weight`` is set to
    ``lm_head[tokens[1]] - lm_head[tokens[0]]``. The temporary head is then
    removed.

    Args:
        model: The sequence-classification wrapper; this function reads its
            ``vllm_config``, ``config``, ``quant_config``, ``model`` and
            ``score`` attributes.
        weights: Iterable of ``(name, tensor)`` checkpoint entries.

    Returns:
        The set of loaded weight names, with "score.weight" added and
        "lm_head.weight" removed.
    """
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3

    # Lazy import
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead)
    from vllm.model_executor.models.utils import AutoWeightsLoader
    from vllm.transformers_utils.tokenizer import get_tokenizer

    model_config = model.vllm_config.model_config
    tokens = getattr(model.config, "classifier_from_token", [])
    # The entries are token strings (they are passed to
    # convert_tokens_to_ids below), not token ids.
    tokens = cast(list[str], tokens)
    assert len(tokens) == 2

    device = model.score.weight.device

    # Attach a temporary LM head so the causal-LM checkpoint can be loaded.
    if model.config.tie_word_embeddings:
        model.lm_head = model.model.embed_tokens
    else:
        model.lm_head = ParallelLMHead(model.config.vocab_size,
                                       model.config.hidden_size,
                                       quant_config=model.quant_config)

    loader = AutoWeightsLoader(model)
    loaded_weights = loader.load_weights(weights)

    tokenizer = get_tokenizer(model_config.tokenizer,
                              revision=model_config.tokenizer_revision,
                              tokenizer_mode=model_config.tokenizer_mode,
                              trust_remote_code=model_config.trust_remote_code)

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])

    # The subtraction is carried out in float32 on the score head's device.
    lm_head_weight = model.lm_head.weight.data
    weight = (lm_head_weight[true_id].to(device).to(torch.float32) -
              lm_head_weight[false_id].to(device).to(torch.float32))
    model.score.weight.data.copy_(weight)

    # The LM head was only needed to derive the score weight.
    del model.lm_head
    loaded_weights.add("score.weight")
    loaded_weights.discard("lm_head.weight")
    return loaded_weights
# Registry of online-conversion weight loaders, keyed by the `method` value
# read from the HF config. Each loader is called as loader(model, weights).
SEQ_CLS_LOAD_METHODS = {
    "from_2_way_softmax": load_weights_using_from_2_way_softmax,
}
def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Dispatch weight loading to the loader named by the HF config's method.

    The ``method`` attribute of the model's HF config must be a key of
    ``SEQ_CLS_LOAD_METHODS``; the matching loader is called with
    ``(model, weights)`` and its result is returned.
    """
    # Online convert ForCausalLM into ForSequenceClassification model.
    # - from_2_way_softmax:
    #   - Qwen3ForCausalLM
    #     - Qwen3-Reranker
    #   - Qwen2ForCausalLM
    #     - mxbai-rerank-v2
    hf_config = model.vllm_config.model_config.hf_config
    method = getattr(hf_config, "method", None)
    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

    load_fn = SEQ_CLS_LOAD_METHODS[method]
    return load_fn(model, weights)
vllm/model_executor/models/config.py
View file @
2e26f915
...
...
@@ -167,7 +167,7 @@ class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
assert
tokens
is
not
None
and
len
(
tokens
)
==
2
,
\
(
"Try loading the original Qwen3 Reranker?, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py"
)
config
.
num_labels
=
1
vllm_
config
.
model_config
.
hf_config
.
method
=
"from_2_way_softmax"
class
SnowflakeGteNewModelConfig
(
VerifyAndUpdateConfig
):
...
...
vllm/model_executor/models/qwen3.py
View file @
2e26f915
...
...
@@ -38,15 +38,14 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsCrossEncoding
,
SupportsLoRA
,
SupportsPP
from
.adapters
import
as_seq_cls_model
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.qwen2
import
Qwen2MLP
as
Qwen3MLP
from
.qwen2
import
Qwen2Model
from
.utils
import
AutoWeightsLoader
,
PPMissingLayer
,
maybe_prefix
...
...
@@ -323,114 +322,4 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return
loader
.
load_weights
(
weights
)
class
Qwen3ForSequenceClassification
(
nn
.
Module
,
SupportsLoRA
,
SupportsCrossEncoding
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
vllm_config
=
vllm_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
prefix
=
prefix
self
.
model
=
Qwen3Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
score
=
RowParallelLinear
(
config
.
hidden_size
,
config
.
num_labels
,
quant_config
=
quant_config
,
input_is_parallel
=
False
,
bias
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"score"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
False
,
softmax
=
True
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
hidden_states
=
self
.
_pooler
.
extract_states
(
hidden_states
,
pooling_metadata
)
if
isinstance
(
hidden_states
,
list
):
logits
=
[
self
.
score
(
state
)[
0
]
for
state
in
hidden_states
]
else
:
logits
,
_
=
self
.
score
(
hidden_states
)
pooled_data
=
self
.
_pooler
.
head
(
logits
,
pooling_metadata
)
pooled_outputs
=
[
self
.
_pooler
.
build_output
(
data
.
squeeze
(
-
1
))
for
data
in
pooled_data
]
return
PoolerOutput
(
outputs
=
pooled_outputs
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
is_original_qwen3_reranker
=
getattr
(
self
.
config
,
"is_original_qwen3_reranker"
,
False
)
if
not
is_original_qwen3_reranker
:
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
return
self
.
load_weights_from_original_qwen3_reranker
(
weights
)
def
load_weights_from_original_qwen3_reranker
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
model_config
=
self
.
vllm_config
.
model_config
tokens
=
getattr
(
self
.
config
,
"classifier_from_token"
,
None
)
device
=
self
.
score
.
weight
.
device
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
,
quant_config
=
self
.
quant_config
,
prefix
=
maybe_prefix
(
self
.
prefix
,
"lm_head"
))
loader
=
AutoWeightsLoader
(
self
)
loaded_weights
=
loader
.
load_weights
(
weights
)
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
tokenizer
=
get_tokenizer
(
model_config
.
tokenizer
,
revision
=
model_config
.
tokenizer_revision
,
tokenizer_mode
=
model_config
.
tokenizer_mode
,
trust_remote_code
=
model_config
.
trust_remote_code
)
a
=
tokenizer
.
convert_tokens_to_ids
(
tokens
[
0
])
b
=
tokenizer
.
convert_tokens_to_ids
(
tokens
[
1
])
weight
=
self
.
lm_head
.
weight
.
data
[
b
].
to
(
device
)
-
self
.
lm_head
.
weight
.
data
[
a
].
to
(
device
)
self
.
score
.
weight
.
data
.
copy_
(
weight
)
del
self
.
lm_head
loaded_weights
.
add
(
"score.weight"
)
loaded_weights
.
discard
(
"lm_head.weight"
)
return
loaded_weights
Qwen3ForSequenceClassification
=
as_seq_cls_model
(
Qwen3ForCausalLM
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment