Unverified commit 84cf78ac, authored by wang.yuqi and committed by GitHub
Browse files

[Model] Pooling models default to using chunked prefill & prefix caching if supported. (#20930)


Signed-off-by: wang.yuqi <noooop@126.com>
parent 16fb668b
......@@ -182,8 +182,8 @@ def as_seq_cls_model(cls: _T) -> _T:
assert pooler_config is not None
pooling_type_str = pooler_config.pooling_type
pooling_type = (PoolingType.LAST if pooling_type_str is None else
PoolingType[pooling_type_str])
assert pooling_type_str is not None
pooling_type = PoolingType[pooling_type_str]
self.pooler = DispatchPooler({
"encode":
......
......@@ -28,7 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces import (SupportsCrossEncoding, SupportsQuant,
default_pooling_type)
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
......@@ -327,6 +328,7 @@ class BertOutput(nn.Module):
@support_torch_compile
@default_pooling_type("CLS")
class BertModel(nn.Module, SupportsQuant):
is_pooling_model = True
......@@ -401,6 +403,7 @@ class BertModel(nn.Module, SupportsQuant):
return loaded_params
@default_pooling_type("ALL")
class BertPoolingModel(BertModel):
is_pooling_model = True
......@@ -431,6 +434,7 @@ class BertPoolingModel(BertModel):
return loaded_params
@default_pooling_type("CLS")
class BertEmbeddingModel(nn.Module, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
......@@ -486,13 +490,8 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"embed":
Pooler.for_embed(
pooler_config,
default_pooling_type=PoolingType.CLS,
),
"encode": Pooler.for_encode(pooler_config),
"embed": Pooler.for_embed(pooler_config),
})
......@@ -541,6 +540,7 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
return token_type_ids
@default_pooling_type("CLS")
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
......
......@@ -27,7 +27,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.model_executor.models.interfaces import (SupportsQuant,
default_pooling_type)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
......@@ -401,6 +402,7 @@ class BertWithRopeEncoder(nn.Module):
@support_torch_compile
@default_pooling_type("CLS")
class BertWithRope(nn.Module, SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
......
......@@ -641,6 +641,20 @@ def supports_cross_encoding(
return is_pooling_model(model) and _supports_cross_encoding(model)
def default_pooling_type(pooling_type: str) -> object:
    """Build a decorator that records a model's default pooling type.

    The returned decorator stores ``pooling_type`` on the decorated object
    as its ``default_pooling_type`` attribute and then returns that same
    object, so it can be stacked with other decorators.

    Args:
        pooling_type: Name of the pooling type to record, e.g. ``"CLS"``,
            ``"ALL"`` or ``"STEP"``.

    Returns:
        A callable that takes the model, sets the attribute on it, and
        returns it.
    """

    def func(model: object) -> object:
        # The attribute is set on the object that is passed in (no wrapper
        # or subclass is created), so a later lookup on that object or on
        # anything inheriting from it sees the recorded value.
        model.default_pooling_type = pooling_type
        return model

    return func
def get_default_pooling_type(model: Union[type[object], object]) -> str:
    """Return the default pooling type recorded on a model class or instance.

    Args:
        model: A model class or instance to inspect.

    Returns:
        The value of the ``default_pooling_type`` attribute if ``model`` has
        one, otherwise ``"LAST"``.
    """
    # getattr with a default avoids raising for objects that carry no
    # ``default_pooling_type`` attribute.
    return getattr(model, "default_pooling_type", "LAST")
class SupportsQuant:
"""The interface required for all models that support quantization."""
......
......@@ -31,7 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -401,6 +401,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
return loaded_params
@default_pooling_type("ALL")
class InternLM2ForRewardModel(InternLM2ForCausalLM):
is_pooling_model = True
......
......@@ -22,8 +22,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
PoolingType)
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
......@@ -604,6 +603,5 @@ class JambaForSequenceClassification(JambaForCausalLM):
Pooler.for_classify(
pooler_config,
classifier=self.score,
default_pooling_type=PoolingType.LAST,
),
})
......@@ -26,7 +26,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .interfaces import (SupportsCrossEncoding, SupportsV0Only,
default_pooling_type)
from .utils import WeightsMapper, maybe_prefix
......@@ -201,6 +202,7 @@ class ModernBertEncoderLayer(nn.Module):
@support_torch_compile
@default_pooling_type("CLS")
class ModernBertModel(nn.Module):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"layers.": "encoder_layer.layers."})
......@@ -264,7 +266,6 @@ class ModernBertPooler(Pooler):
self.pooling = PoolingMethod.from_pooling_type(pooling_type)
self.dense = nn.Linear(config.hidden_size, config.hidden_size,
config.classifier_bias)
self.pooling_type = config.classifier_pooling
self.act = nn.GELU()
self.norm = nn.LayerNorm(config.hidden_size,
eps=config.norm_eps,
......@@ -294,6 +295,7 @@ class ModernBertPooler(Pooler):
return pooled_output
@default_pooling_type("CLS")
class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding):
......
......@@ -15,11 +15,10 @@ from torch import nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
PoolingType)
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, maybe_prefix
......@@ -90,6 +89,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
return loader.load_weights(weights)
@default_pooling_type("ALL")
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......@@ -103,6 +103,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
{"encode": Pooler.for_encode(pooler_config)}, )
@default_pooling_type("STEP")
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......@@ -112,10 +113,5 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(
pooler_config,
default_pooling_type=PoolingType.STEP,
)
})
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)})
......@@ -25,8 +25,8 @@ from vllm.logger import init_logger
from vllm.transformers_utils.dynamic_module import (
try_get_class_from_dynamic_module)
from .interfaces import (has_inner_state, has_noops, is_attention_free,
is_hybrid, supports_cross_encoding,
from .interfaces import (get_default_pooling_type, has_inner_state, has_noops,
is_attention_free, is_hybrid, supports_cross_encoding,
supports_multimodal, supports_multimodal_raw_input,
supports_pp, supports_transcription, supports_v0_only)
from .interfaces_base import is_pooling_model, is_text_generation_model
......@@ -305,6 +305,7 @@ class _ModelInfo:
architecture: str
is_text_generation_model: bool
is_pooling_model: bool
default_pooling_type: str
supports_cross_encoding: bool
supports_multimodal: bool
supports_multimodal_raw_input: bool
......@@ -323,6 +324,7 @@ class _ModelInfo:
architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model),
is_pooling_model=is_pooling_model(model),
default_pooling_type=get_default_pooling_type(model),
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
......
......@@ -23,7 +23,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
from vllm.sequence import IntermediateTensors
from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding
from .interfaces import SupportsCrossEncoding, default_pooling_type
class RobertaEmbedding(nn.Module):
......@@ -86,6 +86,7 @@ class RobertaClassificationHead(nn.Module):
return x
@default_pooling_type("CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities.
......@@ -149,6 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
return loader.load_weights(weights_list, mapper=mapper)
@default_pooling_type("CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities.
......
......@@ -1272,7 +1272,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if not is_pooling_model(model):
return []
return list(model.pooler.get_supported_tasks())
supported_tasks = list(model.pooler.get_supported_tasks())
if (self.scheduler_config.chunked_prefill_enabled
and "encode" in supported_tasks):
supported_tasks.remove("encode")
logger.info_once("Chunked prefill is not supported with "
"encode task which using ALL pooling. "
"Please turn off chunked prefill by "
"`--no-enable-chunked-prefill` before using it.")
return supported_tasks
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks = list[SupportedTask]()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment