Unverified commit 802748bd, authored by wang.yuqi and committed by GitHub
Browse files

[Bugfix] Fix Qwen3-Reranker-8B load (#28117)


Signed-off-by: wang.yuqi <noooop@126.com>
parent faedbb4d
...@@ -186,15 +186,21 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T: ...@@ -186,15 +186,21 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T:
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
raise NotImplementedError raise NotImplementedError
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
load_lm_head: bool = False,
):
# TODO: Support uninitialized params tracking # TODO: Support uninitialized params tracking
# We have deleted this attribute, so don't load it # For most pooling models: We have deleted this attribute, so don't load it.
weights = ( # For converting an LLM into a seq cls model, we need the lm_head.
(name, data) if not load_lm_head:
for name, data in weights weights = (
if not name.startswith("lm_head.") (name, data)
) for name, data in weights
if not name.startswith("lm_head.")
)
# If `*ForCausalLM` defines `load_weights` on the inner model # If `*ForCausalLM` defines `load_weights` on the inner model
# and there are no other inner modules with parameters, # and there are no other inner modules with parameters,
...@@ -431,8 +437,12 @@ def load_weights_using_from_2_way_softmax( ...@@ -431,8 +437,12 @@ def load_weights_using_from_2_way_softmax(
) )
model.lm_head = model.lm_head.tie_weights(embed_tokens) model.lm_head = model.lm_head.tie_weights(embed_tokens)
# Skip ModelForSequenceClassification in MRO to avoid infinite recursion # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
loaded_weights = type(model).__mro__[1].load_weights(model, weights) # function, so we need use this hacky method to obtain it.
pooling_model_cls = next(
x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
)
loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True)
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment