Unverified commit 84cf78ac, authored by wang.yuqi and committed by GitHub
Browse files

[Model] Pooling models default to using chunked prefill & prefix caching if supported. (#20930)


Signed-off-by: wang.yuqi <noooop@126.com>
parent 16fb668b
...@@ -182,8 +182,8 @@ def as_seq_cls_model(cls: _T) -> _T: ...@@ -182,8 +182,8 @@ def as_seq_cls_model(cls: _T) -> _T:
assert pooler_config is not None assert pooler_config is not None
pooling_type_str = pooler_config.pooling_type pooling_type_str = pooler_config.pooling_type
pooling_type = (PoolingType.LAST if pooling_type_str is None else assert pooling_type_str is not None
PoolingType[pooling_type_str]) pooling_type = PoolingType[pooling_type_str]
self.pooler = DispatchPooler({ self.pooler = DispatchPooler({
"encode": "encode":
......
...@@ -28,7 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata ...@@ -28,7 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from .interfaces import SupportsCrossEncoding, SupportsQuant from .interfaces import (SupportsCrossEncoding, SupportsQuant,
default_pooling_type)
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
...@@ -327,6 +328,7 @@ class BertOutput(nn.Module): ...@@ -327,6 +328,7 @@ class BertOutput(nn.Module):
@support_torch_compile @support_torch_compile
@default_pooling_type("CLS")
class BertModel(nn.Module, SupportsQuant): class BertModel(nn.Module, SupportsQuant):
is_pooling_model = True is_pooling_model = True
...@@ -401,6 +403,7 @@ class BertModel(nn.Module, SupportsQuant): ...@@ -401,6 +403,7 @@ class BertModel(nn.Module, SupportsQuant):
return loaded_params return loaded_params
@default_pooling_type("ALL")
class BertPoolingModel(BertModel): class BertPoolingModel(BertModel):
is_pooling_model = True is_pooling_model = True
...@@ -431,6 +434,7 @@ class BertPoolingModel(BertModel): ...@@ -431,6 +434,7 @@ class BertPoolingModel(BertModel):
return loaded_params return loaded_params
@default_pooling_type("CLS")
class BertEmbeddingModel(nn.Module, SupportsQuant): class BertEmbeddingModel(nn.Module, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities. """A model that uses Bert to provide embedding functionalities.
...@@ -486,13 +490,8 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): ...@@ -486,13 +490,8 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return DispatchPooler({ return DispatchPooler({
"encode": "encode": Pooler.for_encode(pooler_config),
Pooler.for_encode(pooler_config), "embed": Pooler.for_embed(pooler_config),
"embed":
Pooler.for_embed(
pooler_config,
default_pooling_type=PoolingType.CLS,
),
}) })
...@@ -541,6 +540,7 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: ...@@ -541,6 +540,7 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
return token_type_ids return token_type_ids
@default_pooling_type("CLS")
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsQuant): SupportsQuant):
"""A model that uses Bert to provide embedding functionalities. """A model that uses Bert to provide embedding functionalities.
......
...@@ -27,7 +27,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -27,7 +27,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsQuant from vllm.model_executor.models.interfaces import (SupportsQuant,
default_pooling_type)
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -401,6 +402,7 @@ class BertWithRopeEncoder(nn.Module): ...@@ -401,6 +402,7 @@ class BertWithRopeEncoder(nn.Module):
@support_torch_compile @support_torch_compile
@default_pooling_type("CLS")
class BertWithRope(nn.Module, SupportsQuant): class BertWithRope(nn.Module, SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
......
...@@ -641,6 +641,20 @@ def supports_cross_encoding( ...@@ -641,6 +641,20 @@ def supports_cross_encoding(
return is_pooling_model(model) and _supports_cross_encoding(model) return is_pooling_model(model) and _supports_cross_encoding(model)
def default_pooling_type(pooling_type: str):
    """Build a class decorator that records a model's default pooling type.

    The returned decorator stores ``pooling_type`` on the decorated object
    as the ``default_pooling_type`` attribute and returns that same object,
    so it can be stacked with other class decorators.

    Args:
        pooling_type: Name of the pooling type to record; the models in this
            change use ``"CLS"``, ``"ALL"`` and ``"STEP"``.

    Returns:
        A decorator that tags its argument and returns it unchanged.
    """
    # Imported locally so this block does not depend on which ``typing``
    # names the enclosing module happens to import.
    from typing import TypeVar

    # Using a TypeVar keeps the decorated class's own type visible to type
    # checkers instead of widening it to ``object``.
    _ModelT = TypeVar("_ModelT")

    def func(model: _ModelT) -> _ModelT:
        # Plain attribute assignment: subclasses inherit the value unless
        # they are decorated themselves.
        model.default_pooling_type = pooling_type  # type: ignore[attr-defined]
        return model

    return func
def get_default_pooling_type(model: Union[type[object], object]) -> str:
    """Return the default pooling type recorded on a model class or instance.

    Args:
        model: A model class or an instance of one.

    Returns:
        The value of the ``default_pooling_type`` attribute when present
        (including when inherited from a base class), otherwise ``"LAST"``.
    """
    # ``getattr`` follows normal attribute lookup, so a subclass without its
    # own value falls back to its parent's before falling back to "LAST".
    return getattr(model, "default_pooling_type", "LAST")
class SupportsQuant: class SupportsQuant:
"""The interface required for all models that support quantization.""" """The interface required for all models that support quantization."""
......
...@@ -31,7 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -31,7 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -401,6 +401,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -401,6 +401,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
return loaded_params return loaded_params
@default_pooling_type("ALL")
class InternLM2ForRewardModel(InternLM2ForCausalLM): class InternLM2ForRewardModel(InternLM2ForCausalLM):
is_pooling_model = True is_pooling_model = True
......
...@@ -22,8 +22,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -22,8 +22,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator) MambaStateShapeCalculator)
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
...@@ -604,6 +603,5 @@ class JambaForSequenceClassification(JambaForCausalLM): ...@@ -604,6 +603,5 @@ class JambaForSequenceClassification(JambaForCausalLM):
Pooler.for_classify( Pooler.for_classify(
pooler_config, pooler_config,
classifier=self.score, classifier=self.score,
default_pooling_type=PoolingType.LAST,
), ),
}) })
...@@ -26,7 +26,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata ...@@ -26,7 +26,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from .interfaces import SupportsCrossEncoding, SupportsV0Only from .interfaces import (SupportsCrossEncoding, SupportsV0Only,
default_pooling_type)
from .utils import WeightsMapper, maybe_prefix from .utils import WeightsMapper, maybe_prefix
...@@ -201,6 +202,7 @@ class ModernBertEncoderLayer(nn.Module): ...@@ -201,6 +202,7 @@ class ModernBertEncoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
@default_pooling_type("CLS")
class ModernBertModel(nn.Module): class ModernBertModel(nn.Module):
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"layers.": "encoder_layer.layers."}) orig_to_new_prefix={"layers.": "encoder_layer.layers."})
...@@ -264,7 +266,6 @@ class ModernBertPooler(Pooler): ...@@ -264,7 +266,6 @@ class ModernBertPooler(Pooler):
self.pooling = PoolingMethod.from_pooling_type(pooling_type) self.pooling = PoolingMethod.from_pooling_type(pooling_type)
self.dense = nn.Linear(config.hidden_size, config.hidden_size, self.dense = nn.Linear(config.hidden_size, config.hidden_size,
config.classifier_bias) config.classifier_bias)
self.pooling_type = config.classifier_pooling
self.act = nn.GELU() self.act = nn.GELU()
self.norm = nn.LayerNorm(config.hidden_size, self.norm = nn.LayerNorm(config.hidden_size,
eps=config.norm_eps, eps=config.norm_eps,
...@@ -294,6 +295,7 @@ class ModernBertPooler(Pooler): ...@@ -294,6 +295,7 @@ class ModernBertPooler(Pooler):
return pooled_output return pooled_output
@default_pooling_type("CLS")
class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding): SupportsCrossEncoding):
......
...@@ -15,11 +15,10 @@ from torch import nn ...@@ -15,11 +15,10 @@ from torch import nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
PoolingType)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type
from .qwen2 import Qwen2Model from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, maybe_prefix from .utils import AutoWeightsLoader, maybe_prefix
...@@ -90,6 +89,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): ...@@ -90,6 +89,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
return loader.load_weights(weights) return loader.load_weights(weights)
@default_pooling_type("ALL")
class Qwen2ForRewardModel(Qwen2RewardBaseModel): class Qwen2ForRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -103,6 +103,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel): ...@@ -103,6 +103,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
{"encode": Pooler.for_encode(pooler_config)}, ) {"encode": Pooler.for_encode(pooler_config)}, )
@default_pooling_type("STEP")
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -112,10 +113,5 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): ...@@ -112,10 +113,5 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler({ self.pooler = DispatchPooler(
"encode": {"encode": Pooler.for_encode(pooler_config)})
Pooler.for_encode(
pooler_config,
default_pooling_type=PoolingType.STEP,
)
})
...@@ -25,8 +25,8 @@ from vllm.logger import init_logger ...@@ -25,8 +25,8 @@ from vllm.logger import init_logger
from vllm.transformers_utils.dynamic_module import ( from vllm.transformers_utils.dynamic_module import (
try_get_class_from_dynamic_module) try_get_class_from_dynamic_module)
from .interfaces import (has_inner_state, has_noops, is_attention_free, from .interfaces import (get_default_pooling_type, has_inner_state, has_noops,
is_hybrid, supports_cross_encoding, is_attention_free, is_hybrid, supports_cross_encoding,
supports_multimodal, supports_multimodal_raw_input, supports_multimodal, supports_multimodal_raw_input,
supports_pp, supports_transcription, supports_v0_only) supports_pp, supports_transcription, supports_v0_only)
from .interfaces_base import is_pooling_model, is_text_generation_model from .interfaces_base import is_pooling_model, is_text_generation_model
...@@ -305,6 +305,7 @@ class _ModelInfo: ...@@ -305,6 +305,7 @@ class _ModelInfo:
architecture: str architecture: str
is_text_generation_model: bool is_text_generation_model: bool
is_pooling_model: bool is_pooling_model: bool
default_pooling_type: str
supports_cross_encoding: bool supports_cross_encoding: bool
supports_multimodal: bool supports_multimodal: bool
supports_multimodal_raw_input: bool supports_multimodal_raw_input: bool
...@@ -323,6 +324,7 @@ class _ModelInfo: ...@@ -323,6 +324,7 @@ class _ModelInfo:
architecture=model.__name__, architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model), is_text_generation_model=is_text_generation_model(model),
is_pooling_model=is_pooling_model(model), is_pooling_model=is_pooling_model(model),
default_pooling_type=get_default_pooling_type(model),
supports_cross_encoding=supports_cross_encoding(model), supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model), supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input=supports_multimodal_raw_input(model), supports_multimodal_raw_input=supports_multimodal_raw_input(model),
......
...@@ -23,7 +23,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, ...@@ -23,7 +23,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .bert_with_rope import BertWithRope, JinaRobertaModel from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding from .interfaces import SupportsCrossEncoding, default_pooling_type
class RobertaEmbedding(nn.Module): class RobertaEmbedding(nn.Module):
...@@ -86,6 +86,7 @@ class RobertaClassificationHead(nn.Module): ...@@ -86,6 +86,7 @@ class RobertaClassificationHead(nn.Module):
return x return x
@default_pooling_type("CLS")
class RobertaEmbeddingModel(BertEmbeddingModel): class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities. """A model that uses Roberta to provide embedding functionalities.
...@@ -149,6 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel): ...@@ -149,6 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
return loader.load_weights(weights_list, mapper=mapper) return loader.load_weights(weights_list, mapper=mapper)
@default_pooling_type("CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities. """A model that uses Roberta to provide embedding functionalities.
......
...@@ -1272,7 +1272,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1272,7 +1272,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if not is_pooling_model(model): if not is_pooling_model(model):
return [] return []
return list(model.pooler.get_supported_tasks()) supported_tasks = list(model.pooler.get_supported_tasks())
if (self.scheduler_config.chunked_prefill_enabled
and "encode" in supported_tasks):
supported_tasks.remove("encode")
logger.info_once("Chunked prefill is not supported with "
"encode task which using ALL pooling. "
"Please turn off chunked prefill by "
"`--no-enable-chunked-prefill` before using it.")
return supported_tasks
def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks = list[SupportedTask]() tasks = list[SupportedTask]()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment