Commit 13370712 (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Model] Replace embedding models with pooling adapter (#10769)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 7e4bbda5
......@@ -422,9 +422,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
......
......@@ -151,9 +151,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self.quant_config = quant_config
config.text_config.architectures = ["GemmaForCausalLM"]
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale
......
......@@ -29,24 +29,22 @@ from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
logger = init_logger(__name__)
......@@ -536,7 +534,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
......@@ -556,18 +553,17 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
quant_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))
# The prefix is empty intentionally because default prefix of
# LlamaForCausalLM is "model"
self.language_model = LlamaForCausalLM(vllm_config=vllm_config,
prefix="")
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
# The prefix is empty intentionally because default prefix of
# LlamaForCausalLM is "model"
prefix="",
# We don't directly initialize vLLM's LlamaForCausalLM so we
# can automatically apply embedding wrapper if this model is
# initialized as an embedding model
architectures=["LlamaForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
......@@ -739,13 +735,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
hf_to_vllm_mapper = WeightsMapper(
......
......@@ -172,9 +172,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# init MistralForCausalLM
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.vision_encoder = VisionTransformer(self.vision_args)
self.vision_language_adapter = VisionLanguageAdapter(
......
......@@ -31,6 +31,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -55,6 +56,8 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
class Qwen2MLP(nn.Module):
......@@ -433,7 +436,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.config = config
self.lora_config = lora_config
......@@ -454,14 +456,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
......@@ -499,13 +493,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
......@@ -553,6 +540,15 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
# TODO: Replace this model class with for_embedding(Qwen2ForCausalLM),
# after changing the default pooling method
if pooler_config.pooling_type is None:
logger.warning(
"This embedding model will default to last-token pooling in "
"an upcoming version. To avoid breaking changes, you should "
"pass `--override-pooler-config '{\"pooling_type\": \"MEAN\"}'`"
" explicitly.")
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.MEAN,
......
......@@ -50,7 +50,6 @@ from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
......@@ -59,14 +58,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor
......@@ -1070,7 +1068,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config
assert not cache_config.enable_prefix_caching, \
"Qwen2-VL currently does not support prefix caching"
......@@ -1102,11 +1099,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
......@@ -1361,13 +1354,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
......
......@@ -20,6 +20,7 @@ import torch.nn as nn
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .adapters import as_embedding_model
from .interfaces import (has_inner_state, is_attention_free,
supports_cross_encoding, supports_multimodal,
supports_pp)
......@@ -107,15 +108,15 @@ _EMBEDDING_MODELS = {
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"LlamaModel": ("llama", "LlamaEmbeddingModel"),
"LlamaModel": ("llama", "LlamaForCausalLM"),
**{
# Multiple models share the same architecture, so we include them all
k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
if arch == "LlamaForCausalLM"
},
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"MistralModel": ("llama", "LlamaForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
......@@ -125,7 +126,7 @@ _EMBEDDING_MODELS = {
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
}
_CROSS_ENCODER_MODELS = {
......@@ -208,6 +209,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
@dataclass(frozen=True)
class _ModelInfo:
architecture: str
is_text_generation_model: bool
is_embedding_model: bool
supports_cross_encoding: bool
......@@ -218,9 +220,19 @@ class _ModelInfo:
@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
is_embedding_model_ = is_embedding_model(model)
if not is_embedding_model_:
try:
as_embedding_model(model)
except Exception:
pass
else:
is_embedding_model_ = True
return _ModelInfo(
architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model),
is_embedding_model=is_embedding_model(model),
is_embedding_model=is_embedding_model_,
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model),
......@@ -399,13 +411,13 @@ class _ModelRegistry:
def inspect_model_cls(
self,
architectures: Union[str, List[str]],
) -> _ModelInfo:
) -> Tuple[_ModelInfo, str]:
architectures = self._normalize_archs(architectures)
for arch in architectures:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None:
return model_info
return (model_info, arch)
return self._raise_for_unsupported(architectures)
......@@ -426,39 +438,50 @@ class _ModelRegistry:
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).is_text_generation_model
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_text_generation_model
def is_embedding_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).is_embedding_model
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_embedding_model
def is_cross_encoder_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).supports_cross_encoding
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_cross_encoding
def is_multimodal_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).supports_multimodal
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_multimodal
def is_pp_supported_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).supports_pp
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_pp
def model_has_inner_state(self, architectures: Union[str,
List[str]]) -> bool:
return self.inspect_model_cls(architectures).has_inner_state
def model_has_inner_state(
self,
architectures: Union[str, List[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.has_inner_state
def is_attention_free_model(self, architectures: Union[str,
List[str]]) -> bool:
return self.inspect_model_cls(architectures).is_attention_free
def is_attention_free_model(
self,
architectures: Union[str, List[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_attention_free
ModelRegistry = _ModelRegistry({
......
......@@ -360,9 +360,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
))
self.multi_modal_projector = UltravoxProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
if config.text_model_id is not None:
# this prefix is not for initialization, but for loading weights
# note the trailing dot
......
......@@ -173,8 +173,15 @@ class AutoWeightsLoader:
module_load_weights = getattr(module, "load_weights", None)
if callable(module_load_weights):
loaded_params = module_load_weights(weights)
yield from map(lambda x: self._get_qualname(base_prefix, x),
loaded_params)
if loaded_params is None:
logger.warning(
"Unable to collect loaded parameters "
"for module %s", module)
else:
yield from map(
lambda x: self._get_qualname(base_prefix, x),
loaded_params,
)
child_modules = dict(module.named_children())
child_params = dict(module.named_parameters(recurse=False))
......@@ -232,17 +239,24 @@ class AutoWeightsLoader:
def init_vllm_registered_model(
hf_config: PretrainedConfig,
vllm_config: VllmConfig,
*,
prefix: str = "",
hf_config: Optional[PretrainedConfig] = None,
architectures: Optional[list[str]] = None,
) -> nn.Module:
"""
Helper function to initialize an inner model registered to vLLM,
based on the arguments passed to the outer vLLM model.
"""
from vllm.model_executor.model_loader.loader import _initialize_model
vllm_config = vllm_config.with_hf_config(hf_config)
return _initialize_model(vllm_config, prefix)
if hf_config is not None:
vllm_config = vllm_config.with_hf_config(hf_config)
return _initialize_model(vllm_config=vllm_config,
prefix=prefix,
architectures=architectures)
@overload
......
......@@ -7,7 +7,7 @@ from torch import nn
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import (get_allowed_kwarg_only_overrides,
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs)
if TYPE_CHECKING:
......@@ -54,8 +54,8 @@ class MultiModalPlugin(ABC):
"""
def __init__(self) -> None:
self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}
self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]()
self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]()
@abstractmethod
def get_data_key(self) -> str:
......
......@@ -9,6 +9,7 @@ from typing_extensions import TypeAlias
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import ClassRegistry
from .audio import AudioPlugin
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
......@@ -62,8 +63,8 @@ class MultiModalRegistry:
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
self._plugins = {p.get_data_key(): p for p in plugins}
self._processor_factories: Dict[Type[nn.Module],
MultiModalProcessorFactory] = {}
self._processor_factories = ClassRegistry[nn.Module,
MultiModalProcessorFactory]()
# This is used for non-multimodal models
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
......
......@@ -20,7 +20,7 @@ import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections import defaultdict
from collections import UserDict, defaultdict
from collections.abc import Iterable, Mapping
from functools import lru_cache, partial, wraps
from platform import uname
......@@ -1517,13 +1517,13 @@ class AtomicCounter:
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping, Generic[T]):
class LazyDict(Mapping[str, T], Generic[T]):
def __init__(self, factory: Dict[str, Callable[[], T]]):
self._factory = factory
self._dict: Dict[str, T] = {}
def __getitem__(self, key) -> T:
def __getitem__(self, key: str) -> T:
if key not in self._dict:
if key not in self._factory:
raise KeyError(key)
......@@ -1540,6 +1540,22 @@ class LazyDict(Mapping, Generic[T]):
return len(self._factory)
class ClassRegistry(UserDict[type[T], _V]):
    """Dictionary keyed by classes whose lookups fall back to base classes.

    Values are stored under the exact class used at registration time.
    Reading with a class walks that class's method resolution order and
    returns the value stored for the first class found in the registry,
    so an entry registered for a base class is also visible through its
    subclasses.
    """

    def __getitem__(self, key: type[T]) -> _V:
        """Return the value registered for ``key`` or its nearest base.

        Raises:
            KeyError: If ``key`` is not a class, or neither ``key`` nor any
                class in its MRO has been registered.
        """
        # Read ``__mro__`` rather than calling ``key.mro()``: when ``key`` is
        # ``type`` itself (or any metaclass), ``mro`` is an unbound method and
        # calling it without an argument raises ``TypeError``.
        if isinstance(key, type):
            for cls in key.__mro__:
                if cls in self.data:
                    return self.data[cls]

        raise KeyError(key)

    def __contains__(self, key: object) -> bool:
        """Return whether ``key`` or any class in its MRO is registered.

        Non-class objects are never considered registered.
        """
        if not isinstance(key, type):
            return False

        return any(cls in self.data for cls in key.__mro__)
def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""
Create a weak reference to a tensor.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment