Unverified commit f600d519, authored by Hanjun Cho, committed by GitHub
Browse files

[Bugfix] Fix score layer quantization for sequence classification models -...


[Bugfix] Fix score layer quantization for sequence classification models - Qwen3 (VL) Reranker (#35849)
Signed-off-by: Hanjun Cho <gkswns0531@gmail.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
parent 8e782013
...@@ -288,15 +288,37 @@ def as_seq_cls_model(cls: _T) -> _T: ...@@ -288,15 +288,37 @@ def as_seq_cls_model(cls: _T) -> _T:
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
prefix: str = "", prefix: str = "",
) -> "Pooler": ) -> "Pooler":
text_config = vllm_config.model_config.hf_config.get_text_config() hf_config = vllm_config.model_config.hf_config
text_config = hf_config.get_text_config()
model_config = vllm_config.model_config model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
# Check if score weights are derived online from LM head
# (same condition as load_weights branch)
tokens = getattr(
hf_config,
"classifier_from_token",
getattr(text_config, "classifier_from_token", None),
)
method = getattr(
hf_config,
"method",
getattr(text_config, "method", None),
)
# Online conversion: no score weights in checkpoint, don't
# quantize (small output_dim breaks FP8/Marlin tile alignment).
# Checkpoint-based: respect the model's quant_config.
quant_config = (
None
if (tokens is not None or method is not None)
else vllm_config.quant_config
)
self.score = ReplicatedLinear( self.score = ReplicatedLinear(
model_config.get_hidden_size(), model_config.get_hidden_size(),
text_config.num_labels, text_config.num_labels,
bias=False, bias=False,
params_dtype=vllm_config.model_config.head_dtype, params_dtype=model_config.head_dtype,
quant_config=quant_config, quant_config=quant_config,
return_bias=False, return_bias=False,
prefix=maybe_prefix(prefix, "score"), prefix=maybe_prefix(prefix, "score"),
...@@ -452,7 +474,6 @@ def load_weights_using_from_2_way_softmax( ...@@ -452,7 +474,6 @@ def load_weights_using_from_2_way_softmax(
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
model_config = model.vllm_config.model_config model_config = model.vllm_config.model_config
quant_config = model.vllm_config.quant_config
hf_config = model.config hf_config = model.config
text_config = hf_config.get_text_config() text_config = hf_config.get_text_config()
...@@ -469,7 +490,8 @@ def load_weights_using_from_2_way_softmax( ...@@ -469,7 +490,8 @@ def load_weights_using_from_2_way_softmax(
using_vlm_head = is_vlm and hasattr(language_model, "score") using_vlm_head = is_vlm and hasattr(language_model, "score")
language_model.lm_head = ParallelLMHead( language_model.lm_head = ParallelLMHead(
text_config.vocab_size, text_config.hidden_size, quant_config=quant_config text_config.vocab_size,
text_config.hidden_size,
) )
if text_config.tie_word_embeddings: if text_config.tie_word_embeddings:
# embed_tokens is the assumed name for input embeddings. If the model does not # embed_tokens is the assumed name for input embeddings. If the model does not
...@@ -531,7 +553,6 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te ...@@ -531,7 +553,6 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
model_config = model.vllm_config.model_config model_config = model.vllm_config.model_config
quant_config = model.vllm_config.quant_config
text_config = model.config.get_text_config() text_config = model.config.get_text_config()
tokens = getattr(text_config, "classifier_from_token", []) tokens = getattr(text_config, "classifier_from_token", [])
...@@ -543,7 +564,8 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te ...@@ -543,7 +564,8 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
using_vlm_head = is_vlm and hasattr(language_model, "score") using_vlm_head = is_vlm and hasattr(language_model, "score")
language_model.lm_head = ParallelLMHead( language_model.lm_head = ParallelLMHead(
text_config.vocab_size, text_config.hidden_size, quant_config=quant_config text_config.vocab_size,
text_config.hidden_size,
) )
if text_config.tie_word_embeddings: if text_config.tie_word_embeddings:
# embed_tokens is the assumed name for input embeddings. If the model does not # embed_tokens is the assumed name for input embeddings. If the model does not
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment