Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cd4cfee6
Unverified
Commit
cd4cfee6
authored
Jun 27, 2025
by
wang.yuqi
Committed by
GitHub
Jun 26, 2025
Browse files
[Model][1/N] Automatic conversion of CrossEncoding model (#20012)
Signed-off-by:
wang.yuqi
<
noooop@126.com
>
parent
e1109306
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
239 additions
and
167 deletions
+239
-167
tests/models/language/pooling/mteb_utils.py
tests/models/language/pooling/mteb_utils.py
+7
-4
vllm/config.py
vllm/config.py
+28
-1
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+1
-148
vllm/model_executor/models/config.py
vllm/model_executor/models/config.py
+200
-0
vllm/model_executor/models/qwen3.py
vllm/model_executor/models/qwen3.py
+3
-14
No files found.
tests/models/language/pooling/mteb_utils.py
View file @
cd4cfee6
...
...
@@ -43,7 +43,7 @@ class VllmMtebEncoder(mteb.Encoder):
# issues by randomizing the order.
r
=
self
.
rng
.
permutation
(
len
(
sentences
))
sentences
=
[
sentences
[
i
]
for
i
in
r
]
outputs
=
self
.
model
.
e
ncode
(
sentences
,
use_tqdm
=
False
)
outputs
=
self
.
model
.
e
mbed
(
sentences
,
use_tqdm
=
False
)
embeds
=
np
.
array
(
outputs
)
embeds
=
embeds
[
np
.
argsort
(
r
)]
return
embeds
...
...
@@ -250,16 +250,19 @@ def mteb_test_rerank_models(hf_runner,
with
vllm_runner
(
model_info
.
name
,
task
=
"score"
,
max_model_len
=
None
,
max_num_seqs
=
8
,
**
vllm_extra_kwargs
)
as
vllm_model
:
model_config
=
vllm_model
.
model
.
llm_engine
.
model_config
if
model_info
.
architecture
:
assert
(
model_info
.
architecture
in
vllm_model
.
model
.
llm_engine
.
model_config
.
architectures
)
assert
(
model_info
.
architecture
in
model_config
.
architectures
)
assert
model_config
.
hf_config
.
num_labels
==
1
vllm_main_score
=
run_mteb_rerank
(
VllmMtebEncoder
(
vllm_model
),
tasks
=
MTEB_RERANK_TASKS
,
languages
=
MTEB_RERANK_LANGS
)
vllm_dtype
=
vllm_model
.
model
.
llm_engine
.
model_config
.
dtype
vllm_dtype
=
model_config
.
dtype
with
hf_runner
(
model_info
.
name
,
is_cross_encoder
=
True
,
dtype
=
"float32"
)
as
hf_model
:
...
...
vllm/config.py
View file @
cd4cfee6
...
...
@@ -569,6 +569,10 @@ class ModelConfig:
else
:
self
.
truncation_side
=
"right"
model_info
,
arch
=
self
.
registry
.
inspect_model_cls
(
self
.
architectures
)
self
.
_model_info
=
model_info
self
.
_architecture
=
arch
self
.
pooler_config
=
self
.
_init_pooler_config
()
self
.
dtype
=
_get_and_verify_dtype
(
...
...
@@ -660,8 +664,18 @@ class ModelConfig:
@property
def architectures(self) -> list[str]:
    """Architecture names declared in the model's HF config.

    Returns:
        The value of ``hf_config.architectures``.  An empty list is
        returned both when the attribute is missing and when it is present
        but unset (``None``), so callers always receive a list as the
        annotation promises.
    """
    # architectures in the model config.
    # ``getattr`` alone only covers the "attribute missing" case; ``or []``
    # additionally normalises an explicit ``None`` to an empty list.
    return getattr(self.hf_config, "architectures", None) or []
@property
def architecture(self) -> str:
    """The architecture vLLM actually used.

    Unlike ``architectures`` (the list taken from the HF config), this
    returns the single value stored in ``self._architecture``.
    """
    selected_arch = self._architecture
    return selected_arch
@property
def model_info(self) -> dict[str, Any]:
    """Return the model info object stored in ``self._model_info``.

    NOTE(review): the annotation says ``dict[str, Any]``; confirm that
    this matches what the registry lookup that fills ``_model_info``
    actually returns.
    """
    cached_info = self._model_info
    return cached_info
def
maybe_pull_model_tokenizer_for_s3
(
self
,
model
:
str
,
tokenizer
:
str
)
->
None
:
"""Pull model/tokenizer from S3 to temporary directory when needed.
...
...
@@ -4450,6 +4464,9 @@ class VllmConfig:
def
__post_init__
(
self
):
"""Verify configs are valid & consistent with each other.
"""
self
.
try_verify_and_update_config
()
if
self
.
model_config
is
not
None
:
self
.
model_config
.
verify_async_output_proc
(
self
.
parallel_config
,
self
.
speculative_config
,
...
...
@@ -4694,11 +4711,21 @@ class VllmConfig:
batch_size_capture_list
)
def recalculate_max_model_len(self, max_model_len: int):
    """Re-verify ``max_model_len`` and store it on the sub-configs.

    The requested length is passed through
    ``model_config.get_and_verify_max_len``; the value that call returns
    is written to both ``model_config.max_model_len`` and
    ``scheduler_config.max_model_len``.  ``compute_hash()`` is invoked
    afterwards.

    Can only be called in ``try_verify_and_update_config``.

    Args:
        max_model_len: The desired maximum model length, before
            verification.
    """
    verified_len = self.model_config.get_and_verify_max_len(max_model_len)
    for sub_config in (self.model_config, self.scheduler_config):
        sub_config.max_model_len = verified_len
    self.compute_hash()
def try_verify_and_update_config(self):
    """Run the architecture-specific config hook, if one is registered.

    Reads ``model_config.architecture``; when it is unavailable (no model
    config, or no such attribute) nothing happens.  Otherwise the
    architecture name is looked up in ``MODELS_CONFIG_MAP`` and, if an
    entry exists, its ``verify_and_update_config`` is called with this
    config object.
    """
    arch_name = getattr(self.model_config, "architecture", None)
    if arch_name is None:
        return

    # Imported inside the function, as in the original.
    # NOTE(review): presumably deferred to avoid an import cycle with
    # vllm.config -- confirm.
    from vllm.model_executor.models.config import MODELS_CONFIG_MAP

    updater = MODELS_CONFIG_MAP.get(arch_name)
    if updater is None:
        return
    updater.verify_and_update_config(self)
def
__str__
(
self
):
return
(
...
...
vllm/model_executor/models/bert_with_rope.py
View file @
cd4cfee6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
copy
import
deepcopy
from
typing
import
Optional
import
torch
...
...
@@ -12,7 +11,6 @@ from vllm.attention import Attention, AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
(
get_act_and_mul_fn
,
get_act_fn
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -30,8 +28,6 @@ from vllm.model_executor.models.interfaces import SupportsQuant
from
vllm.model_executor.models.utils
import
WeightsMapper
from
vllm.sequence
import
IntermediateTensors
logger
=
init_logger
(
__name__
)
class
BertWithRopeEmbedding
(
nn
.
Module
):
...
...
@@ -408,7 +404,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
vllm_config
=
vllm_config
self
.
config
=
self
.
config_verify
(
vllm
_config
)
self
.
config
=
vllm_config
.
model_config
.
hf
_config
self
.
embeddings
=
BertWithRopeEmbedding
(
self
.
config
)
self
.
encoder
=
BertWithRopeEncoder
(
vllm_config
=
vllm_config
,
...
...
@@ -416,9 +412,6 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
rotary_kwargs
=
self
.
config
.
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.encoder"
)
def
config_verify
(
self
,
vllm_config
):
raise
NotImplementedError
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
...
...
@@ -490,95 +483,6 @@ class NomicBertModel(BertWithRope):
"norm2"
:
"mlp_ln"
,
})
def
config_verify
(
self
,
vllm_config
):
config
=
vllm_config
.
model_config
.
hf_config
assert
config
.
__class__
.
__name__
==
"NomicBertConfig"
assert
config
.
activation_function
in
[
"swiglu"
,
"gelu"
]
config
.
position_embedding_type
=
getattr
(
config
,
"position_embedding_type"
,
"rope"
)
if
config
.
activation_function
==
"swiglu"
:
config
.
hidden_act
=
"silu"
else
:
config
.
hidden_act
=
config
.
activation_function
assert
(
config
.
mlp_fc1_bias
==
config
.
mlp_fc2_bias
==
config
.
qkv_proj_bias
)
config
.
bias
=
config
.
qkv_proj_bias
assert
config
.
rotary_emb_scale_base
is
None
assert
not
config
.
rotary_emb_interleaved
config
.
layer_norm_eps
=
config
.
layer_norm_epsilon
config
.
intermediate_size
=
config
.
n_inner
config
.
hidden_size
=
config
.
n_embd
config
.
num_hidden_layers
=
config
.
n_layer
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
rotary_emb_dim
=
head_dim
*
config
.
rotary_emb_fraction
max_trained_positions
=
getattr
(
config
,
"max_trained_positions"
,
2048
)
config
.
rotary_kwargs
=
{
"head_size"
:
head_dim
,
"rotary_dim"
:
rotary_emb_dim
,
"max_position"
:
max_trained_positions
,
"base"
:
getattr
(
config
,
"rope_theta"
,
config
.
rotary_emb_base
),
"rope_scaling"
:
getattr
(
config
,
"rope_scaling"
,
None
)
}
# we ignore config.rotary_scaling_factor so that for datasets shorter
# than max_trained_positions 2048, the results are consistent
# with SentenceTransformer.
# The context extension uses vllm style rope_theta and rope_scaling.
# See #17785 #18755
if
(
not
vllm_config
.
model_config
.
hf_overrides
and
vllm_config
.
model_config
.
original_max_model_len
is
None
):
# Default
# Reset max_model_len to max_trained_positions.
# nomic-embed-text-v2-moe the length is set to 512
# by sentence_bert_config.json.
max_model_len_before
=
vllm_config
.
model_config
.
max_model_len
max_model_len
=
min
(
vllm_config
.
model_config
.
max_model_len
,
max_trained_positions
)
vllm_config
.
recalculate_max_model_len
(
max_model_len
)
logger
.
warning
(
"Nomic context extension is disabled. "
"Changing max_model_len from %s to %s. "
"To enable context extension, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html"
,
max_model_len_before
,
vllm_config
.
model_config
.
max_model_len
)
else
:
# We need to re-verify max_model_len to avoid lengths
# greater than position_embedding.
model_config
=
vllm_config
.
model_config
hf_text_config
=
model_config
.
hf_text_config
if
isinstance
(
model_config
.
hf_overrides
,
dict
):
# hf_overrides_kw
max_model_len
=
model_config
.
hf_overrides
.
get
(
"max_model_len"
,
vllm_config
.
model_config
.
max_model_len
)
else
:
# hf_overrides_fn
# This might be overridden by sentence_bert_config.json.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
# reset hf_text_config for recalculate_max_model_len.
if
hasattr
(
hf_text_config
,
"max_model_len"
):
delattr
(
hf_text_config
,
"max_model_len"
)
hf_text_config
.
max_position_embeddings
=
max_trained_positions
hf_text_config
.
rope_scaling
=
config
.
rotary_kwargs
[
"rope_scaling"
]
# The priority of sentence_bert_config.json is higher
# than max_position_embeddings
encoder_config
=
deepcopy
(
model_config
.
encoder_config
)
encoder_config
.
pop
(
"max_seq_length"
,
None
)
model_config
.
encoder_config
=
encoder_config
vllm_config
.
recalculate_max_model_len
(
max_model_len
)
return
config
class
GteNewModel
(
BertWithRope
):
# for https://huggingface.co/Alibaba-NLP/new-impl
...
...
@@ -600,24 +504,6 @@ class GteNewModel(BertWithRope):
layer
.
mlp
.
gate_up_proj
.
bias
=
None
layer
.
mlp
.
gate_up_proj
.
skip_bias_add
=
True
def
config_verify
(
self
,
vllm_config
):
config
=
vllm_config
.
model_config
.
hf_config
assert
config
.
__class__
.
__name__
==
"NewConfig"
assert
config
.
hidden_act
==
"gelu"
config
.
hidden_act
=
"geglu"
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
config
.
rotary_kwargs
=
{
"head_size"
:
head_dim
,
"rotary_dim"
:
getattr
(
config
,
"rotary_emb_dim"
,
head_dim
),
"max_position"
:
config
.
max_position_embeddings
,
"base"
:
config
.
rope_theta
,
"rope_scaling"
:
getattr
(
config
,
"rope_scaling"
,
None
)
}
return
config
def
split_up_gate_proj
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
n
=
"mlp.up_gate_proj"
for
name
,
weight
in
weights
:
...
...
@@ -652,24 +538,6 @@ class SnowflakeGteNewModel(GteNewModel):
"attention.o_proj"
:
"attn.out_proj"
,
})
def
config_verify
(
self
,
vllm_config
):
config
=
vllm_config
.
model_config
.
hf_config
assert
config
.
__class__
.
__name__
==
"GteConfig"
assert
config
.
hidden_act
==
"gelu"
config
.
hidden_act
=
"geglu"
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
config
.
rotary_kwargs
=
{
"head_size"
:
head_dim
,
"rotary_dim"
:
getattr
(
config
,
"rotary_emb_dim"
,
head_dim
),
"max_position"
:
config
.
max_position_embeddings
,
"base"
:
config
.
rope_theta
,
"rope_scaling"
:
getattr
(
config
,
"rope_scaling"
,
None
)
}
return
config
class
JinaRobertaModel
(
BertWithRope
):
# for https://huggingface.co/jinaai/jina-embeddings-v3
...
...
@@ -685,21 +553,6 @@ class JinaRobertaModel(BertWithRope):
"norm2"
:
"mlp_ln"
,
})
def
config_verify
(
self
,
vllm_config
):
config
=
vllm_config
.
model_config
.
hf_config
assert
config
.
__class__
.
__name__
==
"XLMRobertaFlashConfig"
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
config
.
rotary_kwargs
=
{
"head_size"
:
head_dim
,
"rotary_dim"
:
getattr
(
config
,
"rotary_emb_dim"
,
head_dim
),
"max_position"
:
config
.
max_position_embeddings
,
"base"
:
getattr
(
config
,
"rope_theta"
,
config
.
rotary_emb_base
),
"rope_scaling"
:
getattr
(
config
,
"rope_scaling"
,
None
)
}
return
config
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/config.py
0 → 100644
View file @
cd4cfee6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
copy
import
deepcopy
from
typing
import
TYPE_CHECKING
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
logger
=
init_logger
(
__name__
)
class VerifyAndUpdateConfig:
    """Base class for per-architecture config verification hooks.

    Subclasses override ``verify_and_update_config`` to validate and
    adjust a ``VllmConfig`` for one model architecture.
    """

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate and update ``vllm_config`` in place.

        Raises:
            NotImplementedError: Always; the base class has no behaviour.
        """
        raise NotImplementedError
class GteNewModelConfig(VerifyAndUpdateConfig):
    """Config hook registered for the ``GteNewModel`` architecture."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Check the HF config and attach ``rotary_kwargs`` to it.

        The HF config must be a ``NewConfig`` whose ``hidden_act`` is
        ``"gelu"``.  ``hidden_act`` is then rewritten to ``"geglu"`` and a
        ``rotary_kwargs`` dict is stored on the config.

        Raises:
            AssertionError: If the config class name or activation does
                not match the expected values.
        """
        hf_config = vllm_config.model_config.hf_config

        assert hf_config.__class__.__name__ == "NewConfig"
        assert hf_config.hidden_act == "gelu"
        hf_config.hidden_act = "geglu"

        head_size = hf_config.hidden_size // hf_config.num_attention_heads
        hf_config.rotary_kwargs = dict(
            head_size=head_size,
            # Falls back to the full head size when the config does not
            # define ``rotary_emb_dim``.
            rotary_dim=getattr(hf_config, "rotary_emb_dim", head_size),
            max_position=hf_config.max_position_embeddings,
            base=hf_config.rope_theta,
            rope_scaling=getattr(hf_config, "rope_scaling", None),
        )
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
    """Config hook registered for the ``XLMRobertaModel`` architecture."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Attach ``rotary_kwargs`` to rotary-position XLM-RoBERTa configs.

        Configs whose ``position_embedding_type`` is not ``"rotary"`` are
        left untouched.  For rotary configs the HF config must be an
        ``XLMRobertaFlashConfig``; a ``rotary_kwargs`` dict is then stored
        on it.

        Raises:
            AssertionError: If a rotary config is not an
                ``XLMRobertaFlashConfig``.
        """
        config = vllm_config.model_config.hf_config

        # This hook runs for every model mapped to this architecture, not
        # only the rotary variant.  Use ``getattr`` so a config that does
        # not define ``position_embedding_type`` is treated as non-rotary
        # instead of raising AttributeError.
        if getattr(config, "position_embedding_type", None) != "rotary":
            return

        assert config.__class__.__name__ == "XLMRobertaFlashConfig"

        head_dim = config.hidden_size // config.num_attention_heads
        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "max_position": config.max_position_embeddings,
            # ``rope_theta`` wins when present; otherwise fall back to the
            # config's ``rotary_emb_base``.
            "base": getattr(config, "rope_theta", config.rotary_emb_base),
            "rope_scaling": getattr(config, "rope_scaling", None),
        }
class NomicBertModelConfig(VerifyAndUpdateConfig):
    """Config hook registered for the ``NomicBertModel`` architecture."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate a ``NomicBertConfig`` and normalise it in place.

        Steps, in order:

        1. Assert the config class and supported option values.
        2. Copy Nomic-specific field names onto the attribute names the
           rest of this function reads (``hidden_act``, ``bias``,
           ``layer_norm_eps``, ``intermediate_size``, ``hidden_size``,
           ``num_hidden_layers``).
        3. Build ``config.rotary_kwargs``.
        4. Recalculate ``max_model_len``: by default it is capped at
           ``max_trained_positions``; when the user supplied
           ``hf_overrides`` or an explicit ``max_model_len`` it is
           re-verified against the updated position settings instead.

        Raises:
            AssertionError: If the config is not a supported
                ``NomicBertConfig``.
        """
        config = vllm_config.model_config.hf_config

        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function in ["swiglu", "gelu"]
        config.position_embedding_type = getattr(config,
                                                 "position_embedding_type",
                                                 "rope")

        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function

        # A single ``bias`` flag is exposed, so all three must agree.
        assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
                config.qkv_proj_bias)
        config.bias = config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer

        head_dim = config.hidden_size // config.num_attention_heads
        # ``rotary_emb_fraction`` may be a float (e.g. 1.0); a dimension
        # must be an integer, so convert explicitly.
        rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
        max_trained_positions = getattr(config, "max_trained_positions", 2048)
        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": rotary_emb_dim,
            "max_position": max_trained_positions,
            "base": getattr(config, "rope_theta", config.rotary_emb_base),
            "rope_scaling": getattr(config, "rope_scaling", None)
        }

        # we ignore config.rotary_scaling_factor so that for datasets shorter
        # than max_trained_positions 2048, the results are consistent
        # with SentenceTransformer.
        # The context extension uses vllm style rope_theta and rope_scaling.
        # See #17785 #18755
        if (not vllm_config.model_config.hf_overrides
                and vllm_config.model_config.original_max_model_len is None):
            # Default
            # Reset max_model_len to max_trained_positions.
            # nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            max_model_len_before = vllm_config.model_config.max_model_len
            max_model_len = min(vllm_config.model_config.max_model_len,
                                max_trained_positions)

            vllm_config.recalculate_max_model_len(max_model_len)
            logger.warning(
                "Nomic context extension is disabled. "
                "Changing max_model_len from %s to %s. "
                "To enable context extension, see: "
                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",  # noqa: E501
                max_model_len_before,
                vllm_config.model_config.max_model_len)
        else:
            # We need to re-verify max_model_len to avoid lengths
            # greater than position_embedding.
            model_config = vllm_config.model_config
            hf_text_config = model_config.hf_text_config

            if isinstance(model_config.hf_overrides, dict):
                # hf_overrides_kw
                max_model_len = model_config.hf_overrides.get(
                    "max_model_len", vllm_config.model_config.max_model_len)
            else:
                # hf_overrides_fn
                # This might be overridden by sentence_bert_config.json.
                max_model_len = vllm_config.model_config.max_model_len

            # reset hf_text_config for recalculate_max_model_len.
            if hasattr(hf_text_config, "max_model_len"):
                delattr(hf_text_config, "max_model_len")
            hf_text_config.max_position_embeddings = max_trained_positions
            hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]

            # The priority of sentence_bert_config.json is higher
            # than max_position_embeddings, so drop its length limit.
            # Work on a copy so the original mapping is not mutated.
            encoder_config = deepcopy(model_config.encoder_config)
            # NOTE(review): guarded because encoder_config is presumably
            # None when the checkpoint ships no sentence-transformers
            # config -- confirm against ModelConfig.
            if encoder_config is not None:
                encoder_config.pop("max_seq_length", None)
            model_config.encoder_config = encoder_config

            vllm_config.recalculate_max_model_len(max_model_len)
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook registered for ``Qwen3ForSequenceClassification``."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Set ``num_labels = 1`` for the original Qwen3 reranker.

        Nothing happens unless the HF config carries a truthy
        ``is_original_qwen3_reranker`` flag.  When it does,
        ``classifier_from_token`` must hold exactly two tokens.

        Raises:
            AssertionError: If the flag is set but
                ``classifier_from_token`` is missing or does not contain
                exactly two entries.
        """
        hf_config = vllm_config.model_config.hf_config

        if getattr(hf_config, "is_original_qwen3_reranker", False):
            classifier_tokens = getattr(hf_config, "classifier_from_token",
                                        None)
            has_token_pair = (classifier_tokens is not None
                              and len(classifier_tokens) == 2)
            assert has_token_pair, (
                "Try loading the original Qwen3 Reranker?, see: "
                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py"  # noqa: E501
            )
            hf_config.num_labels = 1
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
    """Config hook registered for the ``GteModel`` architecture."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Check the HF config and attach ``rotary_kwargs`` to it.

        The HF config must be a ``GteConfig`` whose ``hidden_act`` is
        ``"gelu"``; ``hidden_act`` is then rewritten to ``"geglu"`` and a
        ``rotary_kwargs`` dict is stored on the config.

        Raises:
            AssertionError: If the config class name or activation does
                not match the expected values.
        """
        gte_config = vllm_config.model_config.hf_config

        assert gte_config.__class__.__name__ == "GteConfig"
        assert gte_config.hidden_act == "gelu"
        gte_config.hidden_act = "geglu"

        per_head = gte_config.hidden_size // gte_config.num_attention_heads
        rotary_kwargs = {"head_size": per_head}
        # Falls back to the full head size when ``rotary_emb_dim`` is absent.
        rotary_kwargs["rotary_dim"] = getattr(gte_config, "rotary_emb_dim",
                                              per_head)
        rotary_kwargs["max_position"] = gte_config.max_position_embeddings
        rotary_kwargs["base"] = gte_config.rope_theta
        rotary_kwargs["rope_scaling"] = getattr(gte_config, "rope_scaling",
                                                None)
        gte_config.rotary_kwargs = rotary_kwargs
# Maps an architecture name to the hook class whose
# ``verify_and_update_config`` should run for that architecture.
# Architectures without an entry get no architecture-specific handling.
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewModel": GteNewModelConfig,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    "XLMRobertaModel": JinaRobertaModelConfig,
}
vllm/model_executor/models/qwen3.py
View file @
cd4cfee6
...
...
@@ -400,22 +400,10 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
def
load_weights_from_original_qwen3_reranker
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
tokens
=
getattr
(
self
.
config
,
"classifier_from_token"
,
None
)
assert
tokens
is
not
None
and
len
(
tokens
)
==
2
,
\
(
"Try loading the original Qwen3 Reranker?, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py"
)
self
.
config
.
num_labels
=
1
model_config
=
self
.
vllm_config
.
model_config
tokens
=
getattr
(
self
.
config
,
"classifier_from_token"
,
None
)
device
=
self
.
score
.
weight
.
device
self
.
score
=
RowParallelLinear
(
self
.
config
.
hidden_size
,
self
.
config
.
num_labels
,
quant_config
=
self
.
quant_config
,
input_is_parallel
=
False
,
bias
=
False
,
prefix
=
maybe_prefix
(
self
.
prefix
,
"score"
)).
to
(
device
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
...
...
@@ -443,5 +431,6 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
self
.
score
.
weight
.
data
.
copy_
(
weight
)
del
self
.
lm_head
loaded_weights
.
add
(
"
classifier
.weight"
)
loaded_weights
.
add
(
"
score
.weight"
)
loaded_weights
.
discard
(
"lm_head.weight"
)
return
loaded_weights
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment