Commit 13370712 (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Model] Replace embedding models with pooling adapter (#10769)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 7e4bbda5
...@@ -334,7 +334,6 @@ steps: ...@@ -334,7 +334,6 @@ steps:
commands: commands:
- pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
- pytest -v -s models/embedding/language -m core_model - pytest -v -s models/embedding/language -m core_model
- pytest -v -s models/embedding/vision_language -m core_model
- label: Language Models Test (Extended) # 50min - label: Language Models Test (Extended) # 50min
optional: true optional: true
...@@ -346,7 +345,6 @@ steps: ...@@ -346,7 +345,6 @@ steps:
commands: commands:
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model' - pytest -v -s models/embedding/language -m 'not core_model'
- pytest -v -s models/embedding/vision_language -m 'not core_model'
- label: Multi-Modal Models Test (Standard) # 26min - label: Multi-Modal Models Test (Standard) # 26min
#mirror_hardwares: [amd] #mirror_hardwares: [amd]
...@@ -359,6 +357,7 @@ steps: ...@@ -359,6 +357,7 @@ steps:
commands: commands:
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model
- pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model
...@@ -376,6 +375,7 @@ steps: ...@@ -376,6 +375,7 @@ steps:
# https://github.com/huggingface/transformers/issues/34307 # https://github.com/huggingface/transformers/issues/34307
- pytest -v -s models/decoder_only/vision_language/test_phi3v.py - pytest -v -s models/decoder_only/vision_language/test_phi3v.py
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/vision_language -m 'not core_model'
- pytest -v -s models/encoder_decoder/language -m 'not core_model' - pytest -v -s models/encoder_decoder/language -m 'not core_model'
- pytest -v -s models/encoder_decoder/vision_language -m 'not core_model' - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
......
...@@ -357,7 +357,7 @@ Text Embedding ...@@ -357,7 +357,7 @@ Text Embedding
- ✅︎ - ✅︎
* - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM` * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
- Qwen2-based - Qwen2-based
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. - :code:`ssmits/Qwen2-7B-Instruct-embed-base` (see note), :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- ✅︎ - ✅︎
- ✅︎ - ✅︎
* - :code:`RobertaModel`, :code:`RobertaForMaskedLM` * - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
...@@ -378,6 +378,10 @@ Text Embedding ...@@ -378,6 +378,10 @@ Text Embedding
.. tip:: .. tip::
You can override the model's pooling method by passing :code:`--override-pooler-config`. You can override the model's pooling method by passing :code:`--override-pooler-config`.
.. note::
:code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`.
.. note:: .. note::
Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention. Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly. You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
...@@ -397,12 +401,21 @@ Reward Modeling ...@@ -397,12 +401,21 @@ Reward Modeling
- Example HF Models - Example HF Models
- :ref:`LoRA <lora>` - :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>` - :ref:`PP <distributed_serving>`
* - :code:`LlamaForCausalLM`
- Llama-based
- :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
- ✅︎
- ✅︎
* - :code:`Qwen2ForRewardModel` * - :code:`Qwen2ForRewardModel`
- Qwen2-based - Qwen2-based
- :code:`Qwen/Qwen2.5-Math-RM-72B`, etc. - :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
- ✅︎ - ✅︎
- ✅︎ - ✅︎
.. important::
For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
.. note:: .. note::
As an interim measure, these models are supported in both offline and online inference via Embeddings API. As an interim measure, these models are supported in both offline and online inference via Embeddings API.
......
...@@ -263,7 +263,6 @@ class HfRunner: ...@@ -263,7 +263,6 @@ class HfRunner:
dtype: str = "half", dtype: str = "half",
*, *,
model_kwargs: Optional[Dict[str, Any]] = None, model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False,
is_sentence_transformer: bool = False, is_sentence_transformer: bool = False,
is_cross_encoder: bool = False, is_cross_encoder: bool = False,
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
......
...@@ -4,6 +4,8 @@ Run `pytest tests/models/embedding/language/test_embedding.py`. ...@@ -4,6 +4,8 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
""" """
import pytest import pytest
from vllm.config import PoolerConfig
from ..utils import check_embeddings_close from ..utils import check_embeddings_close
...@@ -33,6 +35,9 @@ def test_models( ...@@ -33,6 +35,9 @@ def test_models(
dtype: str, dtype: str,
) -> None: ) -> None:
vllm_extra_kwargs = {} vllm_extra_kwargs = {}
if model == "ssmits/Qwen2-7B-Instruct-embed-base":
vllm_extra_kwargs["override_pooler_config"] = \
PoolerConfig(pooling_type="MEAN")
if model == "Alibaba-NLP/gte-Qwen2-7B-instruct": if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": False} vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}
......
...@@ -6,11 +6,8 @@ import torch.cuda ...@@ -6,11 +6,8 @@ import torch.cuda
from vllm.model_executor.models import (is_embedding_model, from vllm.model_executor.models import (is_embedding_model,
is_text_generation_model, is_text_generation_model,
supports_multimodal) supports_multimodal)
# yapf conflicts with isort for this block from vllm.model_executor.models.adapters import as_embedding_model
# yapf: disable from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS,
_EMBEDDING_MODELS,
_MULTIMODAL_MODELS,
_SPECULATIVE_DECODING_MODELS, _SPECULATIVE_DECODING_MODELS,
_TEXT_GENERATION_MODELS, _TEXT_GENERATION_MODELS,
ModelRegistry) ModelRegistry)
...@@ -26,18 +23,18 @@ def test_registry_imports(model_arch): ...@@ -26,18 +23,18 @@ def test_registry_imports(model_arch):
model_cls, _ = ModelRegistry.resolve_model_cls(model_arch) model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)
if model_arch in _SPECULATIVE_DECODING_MODELS: if model_arch in _SPECULATIVE_DECODING_MODELS:
pass # Ignore these models which do not have a unified format return # Ignore these models which do not have a unified format
else:
assert is_text_generation_model(model_cls) is ( if (model_arch in _TEXT_GENERATION_MODELS
model_arch in _TEXT_GENERATION_MODELS or model_arch in _MULTIMODAL_MODELS):
or model_arch in _MULTIMODAL_MODELS) assert is_text_generation_model(model_cls)
embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS} # All vLLM models should be convertible to an embedding model
assert is_embedding_model(model_cls) is (model_arch embed_model = as_embedding_model(model_cls)
in embedding_models) assert is_embedding_model(embed_model)
assert supports_multimodal(model_cls) is (model_arch if model_arch in _MULTIMODAL_MODELS:
in _MULTIMODAL_MODELS) assert supports_multimodal(model_cls)
@fork_new_process_for_each_test @fork_new_process_for_each_test
......
from typing import List, Optional, Union from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel from vllm.config import VllmConfig
from vllm.sequence import IntermediateTensors from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
class MyGemma2Embedding(Gemma2EmbeddingModel): class MyGemma2Embedding(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False,
)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -18,7 +39,7 @@ class MyGemma2Embedding(Gemma2EmbeddingModel): ...@@ -18,7 +39,7 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = super().forward( hidden_states = self.model(
input_ids, input_ids,
positions, positions,
kv_caches, kv_caches,
...@@ -32,3 +53,17 @@ class MyGemma2Embedding(Gemma2EmbeddingModel): ...@@ -32,3 +53,17 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
# Return all-zero embeddings # Return all-zero embeddings
return torch.zeros_like(hidden_states) return torch.zeros_like(hidden_states)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
return self.model.load_weights(weights)
...@@ -26,8 +26,7 @@ def test_auto_task(model_id, expected_task): ...@@ -26,8 +26,7 @@ def test_auto_task(model_id, expected_task):
@pytest.mark.parametrize(("model_id", "bad_task"), [ @pytest.mark.parametrize(("model_id", "bad_task"), [
("facebook/opt-125m", "embedding"), ("Qwen/Qwen2.5-Math-RM-72B", "generate"),
("intfloat/e5-mistral-7b-instruct", "generate"),
]) ])
def test_incorrect_task(model_id, bad_task): def test_incorrect_task(model_id, bad_task):
with pytest.raises(ValueError, match=r"does not support the .* task"): with pytest.raises(ValueError, match=r"does not support the .* task"):
......
...@@ -370,6 +370,31 @@ class ModelConfig: ...@@ -370,6 +370,31 @@ class ModelConfig:
selected_task = next(iter(supported_tasks_lst)) selected_task = next(iter(supported_tasks_lst))
if len(supported_tasks) > 1: if len(supported_tasks) > 1:
suffix_to_preferred_task: List[Tuple[str, _Task]] = [
# Hardcode the models that are exceptions
("AquilaModel", "generate"),
("ChatGLMModel", "generate"),
# Other models follow this pattern
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("EmbeddingModel", "embedding"),
("RewardModel", "embedding"),
("ForSequenceClassification", "embedding"),
]
info, arch = ModelRegistry.inspect_model_cls(architectures)
for suffix, pref_task in suffix_to_preferred_task:
if arch.endswith(suffix) and pref_task in supported_tasks:
selected_task = pref_task
break
else:
if (arch.endswith("Model")
and info.architecture.endswith("ForCausalLM")
and "embedding" in supported_tasks):
selected_task = "embedding"
logger.info( logger.info(
"This model supports multiple tasks: %s. " "This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task) "Defaulting to '%s'.", supported_tasks, selected_task)
......
...@@ -11,8 +11,8 @@ from typing_extensions import TypeVar, assert_never ...@@ -11,8 +11,8 @@ from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs) print_warning_once, resolve_mm_processor_kwargs)
from .data import ProcessorInputs, SingletonInputs from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs from .parse import is_encoder_decoder_inputs
...@@ -136,12 +136,12 @@ class InputRegistry: ...@@ -136,12 +136,12 @@ class InputRegistry:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._dummy_factories_by_model_type: Dict[Type[nn.Module], self._dummy_factories_by_model_type = \
DummyDataFactory] = {} ClassRegistry[nn.Module, DummyDataFactory]()
self._dummy_encoder_factories_by_model_type: Dict[ self._dummy_encoder_factories_by_model_type = \
Type[nn.Module], DummyDataFactory] = {} ClassRegistry[nn.Module, DummyDataFactory]()
self._input_processors_by_model_type: Dict[Type[nn.Module], self._input_processors_by_model_type = \
InputProcessor] = {} ClassRegistry[nn.Module, InputProcessor]()
def _default_dummy_data_factory( def _default_dummy_data_factory(
self, self,
......
...@@ -60,9 +60,7 @@ class Pooler(nn.Module): ...@@ -60,9 +60,7 @@ class Pooler(nn.Module):
softmax: bool, softmax: bool,
step_tag_id: Optional[int] = None, step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None, returned_token_ids: Optional[List[int]] = None,
) -> Optional["Pooler"]: ) -> "Pooler":
if pooler_config is None:
return None
return cls( return cls(
pooling_type=PoolingType[pooler_config.pooling_type] pooling_type=PoolingType[pooler_config.pooling_type]
if pooler_config.pooling_type is not None else pooling_type, if pooler_config.pooling_type is not None else pooling_type,
......
...@@ -9,6 +9,7 @@ import itertools ...@@ -9,6 +9,7 @@ import itertools
import json import json
import math import math
import os import os
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
...@@ -97,22 +98,31 @@ def device_loading_context(module: torch.nn.Module, ...@@ -97,22 +98,31 @@ def device_loading_context(module: torch.nn.Module,
logger = init_logger(__name__) logger = init_logger(__name__)
def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: def _initialize_model(
vllm_config: VllmConfig,
*,
prefix: str = "",
architectures: Optional[list[str]] = None,
) -> nn.Module:
"""Initialize a model with the given configurations.""" """Initialize a model with the given configurations."""
model_config = vllm_config.model_config model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config) model_class, _ = get_model_architecture(model_config,
architectures=architectures)
signatures = inspect.signature(model_class.__init__) signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()] all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params: if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class # new-style model class
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
return model_class(vllm_config=vllm_config, prefix=prefix) return model_class(vllm_config=vllm_config, prefix=prefix)
msg = ("vLLM model class should accept `vllm_config` and `prefix` as " msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class" "input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. " " registered from out of tree and it is used for new vLLM version. "
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html " "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
"for the design and update the model class accordingly.") "for the design and update the model class accordingly.")
logger.warning(msg) warnings.warn(msg, DeprecationWarning, stacklevel=2)
logger.warning( logger.warning(
"Trying to guess the arguments for old-style model class %s", "Trying to guess the arguments for old-style model class %s",
model_class, model_class,
...@@ -356,7 +366,7 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -356,7 +366,7 @@ class DefaultModelLoader(BaseModelLoader):
weights_to_load = {name for name, _ in model.named_parameters()} weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights( loaded_weights = model.load_weights(
self._get_all_weights(model_config, model)) self._get_all_weights(model_config, model))
# We only enable strict check for non-quantiized models # We only enable strict check for non-quantized models
# that have loaded weights tracking currently. # that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None: if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights weights_not_loaded = weights_to_load - loaded_weights
......
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading models."""
import contextlib import contextlib
from typing import Tuple, Type from typing import Optional, Tuple, Type
import torch import torch
from torch import nn from torch import nn
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import as_embedding_model
@contextlib.contextmanager @contextlib.contextmanager
...@@ -19,8 +20,13 @@ def set_default_torch_dtype(dtype: torch.dtype): ...@@ -19,8 +20,13 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture( def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig,
architectures = getattr(model_config.hf_config, "architectures", []) *,
architectures: Optional[list[str]] = None,
) -> Tuple[Type[nn.Module], str]:
if architectures is None:
architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
mixtral_supported = [ mixtral_supported = [
...@@ -32,7 +38,11 @@ def get_model_architecture( ...@@ -32,7 +38,11 @@ def get_model_architecture(
and "MixtralForCausalLM" in architectures): and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"] architectures = ["QuantMixtralForCausalLM"]
return ModelRegistry.resolve_model_cls(architectures) model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embedding":
model_cls = as_embedding_model(model_cls)
return model_cls, arch
def get_architecture_class_name(model_config: ModelConfig) -> str: def get_architecture_class_name(model_config: ModelConfig) -> str:
......
from collections.abc import Iterable
from typing import Any, TypeVar
import torch
import torch.nn as nn
from .interfaces_base import VllmModelForEmbedding, is_embedding_model
_T = TypeVar("_T", bound=type[nn.Module])
def as_embedding_model(cls: _T) -> _T:
    """Return a variant of ``cls`` that exposes the embedding interface.

    A class for which ``is_embedding_model`` is already true is returned
    unchanged. Otherwise a new class deriving from both ``cls`` and
    ``VllmModelForEmbedding`` is built. That class:

    * runs the parent constructor and then removes the ``lm_head`` and
      ``logits_processor`` attributes if they exist;
    * creates ``self._pooler`` via ``Pooler.from_config_with_defaults``
      (defaults: ``PoolingType.LAST``, ``normalize=True``,
      ``softmax=False``) unless ``_pooler`` is already set to a truthy
      value;
    * drops every ``lm_head.*`` entry from the weights before loading.

    The new class is named after ``cls`` with a trailing ``ForCausalLM``,
    ``ForConditionalGeneration``, ``ChatModel`` or ``LMHeadModel`` stripped
    (in that order) and ``ForEmbedding`` appended.
    """
    if is_embedding_model(cls):
        # Existing embedding models are handed back untouched
        return cls

    # Lazy imports: resolved on first use instead of at module import time
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput,
                                                   PoolingType)
    from vllm.model_executor.pooling_metadata import PoolingMetadata

    from .utils import AutoWeightsLoader, WeightsMapper

    class ModelForEmbedding(cls, VllmModelForEmbedding):

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # Generation-only components are not used in embedding models
            for unused_name in ("lm_head", "logits_processor"):
                if hasattr(self, unused_name):
                    delattr(self, unused_name)

            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            # A pooler already defined by the wrapped model takes precedence
            if getattr(self, "_pooler", None):
                return

            self._pooler = Pooler.from_config_with_defaults(
                pooler_config,
                pooling_type=PoolingType.LAST,
                normalize=True,
                softmax=False,
            )

        def pooler(
            self,
            hidden_states: torch.Tensor,
            pooling_metadata: PoolingMetadata,
        ) -> PoolerOutput:
            """Apply ``self._pooler`` to the given hidden states."""
            return self._pooler(hidden_states, pooling_metadata)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            """Load ``weights`` into the model, ignoring ``lm_head.*``."""
            # TODO: Support uninitialized params tracking

            # `lm_head` was deleted in `__init__`, so its weights are skipped
            head_free = ((param_name, tensor) for param_name, tensor in weights
                         if not param_name.startswith("lm_head."))

            # When the inner `model` defines `load_weights` and no other
            # child module owns parameters, checkpoints from both `*Model`
            # and `*ForCausalLM` are accepted by stripping the "model." prefix
            inner_has_loader = (hasattr(self, "model")
                                and hasattr(self.model, "load_weights"))
            if inner_has_loader:
                other_child_has_params = any(
                    child_name != "model"
                    and next(child.parameters(), None) is not None
                    for child_name, child in self.named_children())

                if not other_child_has_params:
                    prefix_stripper = WeightsMapper(
                        orig_to_new_prefix={"model.": ""})
                    self.model.load_weights(prefix_stripper.apply(head_free))
                    return

            # Most other models: defer to the wrapped class's own loader
            if hasattr(cls, "load_weights"):
                cls.load_weights(self, head_free)  # type: ignore
                return

            # Fallback when the wrapped class has no loader of its own
            AutoWeightsLoader(self).load_weights(head_free)

    # Derive the public class name: strip a generation suffix, then tag it
    embedding_name = cls.__name__
    for generation_suffix in ("ForCausalLM", "ForConditionalGeneration",
                              "ChatModel", "LMHeadModel"):
        embedding_name = embedding_name.removesuffix(generation_suffix)
    ModelForEmbedding.__name__ = embedding_name + "ForEmbedding"

    return ModelForEmbedding  # type: ignore
...@@ -512,9 +512,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -512,9 +512,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
) )
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model")) hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
......
...@@ -30,19 +30,17 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -30,19 +30,17 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -455,55 +453,3 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -455,55 +453,3 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if self.config.tie_word_embeddings else None), if self.config.tie_word_embeddings else None),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
    """
    Embedding variant of Gemma2: a Gemma2Model paired with a pooler.

    The forward pass is delegated to the wrapped Gemma2Model, and `pooler`
    runs the stored Pooler on the hidden states it is given.

    Attributes:
        model: The Gemma2Model instance that executes the forward pass.
        _pooler: The Pooler instance applied by `pooler`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        inner_prefix = maybe_prefix(prefix, "model")
        self.model = Gemma2Model(vllm_config=vllm_config, prefix=inner_prefix)

        # The keyword values are the defaults handed to the pooler factory
        pooler_config = vllm_config.model_config.pooler_config
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=True,
            softmax=False,
        )

        # Re-export the inner model's factory under the same attribute name
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped Gemma2Model and return its output unchanged."""
        model_output = self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds,
        )
        return model_output

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Apply the stored pooler to `hidden_states`."""
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the wrapped Gemma2Model.

        A leading "model." prefix is stripped from weight names, and
        entries whose name starts with "lm_head." are skipped.
        """
        prefix_stripper = WeightsMapper(orig_to_new_prefix={"model.": ""})
        remapped = prefix_stripper.apply(weights)
        backbone_weights = ((param_name, tensor)
                            for param_name, tensor in remapped
                            if not param_name.startswith("lm_head."))
        self.model.load_weights(backbone_weights)
...@@ -474,9 +474,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -474,9 +474,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
) )
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model")) hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.mlp1 = self._init_mlp1(config) self.mlp1 = self._init_mlp1(config)
......
...@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale) get_compressed_tensors_cache_scale)
...@@ -47,14 +46,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -47,14 +46,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
extract_layer_index, is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -511,11 +509,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -511,11 +509,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config, prefix=prefix) self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
...@@ -544,13 +543,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -544,13 +543,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.sampler = get_sampler() self.sampler = get_sampler()
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.STEP,
normalize=False,
softmax=False)
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return LlamaModel(vllm_config=vllm_config, prefix=prefix) return LlamaModel(vllm_config=vllm_config, prefix=prefix)
...@@ -581,14 +576,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -581,14 +576,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Pool over logits computed from the given hidden states.

    The hidden states are first passed through ``compute_logits`` (with
    ``None`` in place of sampling metadata); the resulting logits, not the
    raw hidden states, are what the configured pooler receives.
    """
    return self._pooler(
        self.compute_logits(hidden_states, None),
        pooling_metadata,
    )
def sample(self, logits: torch.Tensor, def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
...@@ -639,78 +626,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -639,78 +626,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
name = name.replace(item, mapping[item]) name = name.replace(item, mapping[item])
return name, loaded_weight return name, loaded_weight
class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
    """Llama backbone paired with a pooler for embedding workloads.

    Wraps a ``LlamaModel`` and exposes a ``pooler`` entry point alongside
    the usual ``forward``. There is no LM head: weights whose names start
    with ``lm_head.`` are skipped when loading.

    Attributes:
        model: The wrapped ``LlamaModel`` that performs the forward pass.
        _pooler: ``Pooler`` built from the model's pooler config, defaulting
            to last-token pooling with normalization and no softmax.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
    ]
    embedding_modules = {"embed_tokens": "input_embeddings"}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        # Submodules are assigned in the same order as before (model first,
        # then pooler) so module registration order is unchanged.
        self.model = LlamaModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self._pooler = Pooler.from_config_with_defaults(
            vllm_config.model_config.pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=True,
            softmax=False,
        )

        # Re-export the backbone's factory under the same attribute name.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped ``LlamaModel`` and return its output unchanged."""
        backbone_output = self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds,
        )
        return backbone_output

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Apply the configured pooler directly to the hidden states."""
        pooled = self._pooler(hidden_states, pooling_metadata)
        return pooled

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the wrapped backbone.

        Names are first passed through a ``WeightsMapper`` that maps the
        ``model.`` prefix to an empty string; entries whose (remapped) name
        starts with ``lm_head.`` are then dropped before delegating to
        ``self.model.load_weights``.
        """
        prefix_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
        remapped = prefix_mapper.apply(weights)
        # Generator expression keeps the stream lazy, as in the original.
        self.model.load_weights(
            (weight_name, tensor) for weight_name, tensor in remapped
            if not weight_name.startswith("lm_head."))

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Delegate KV-cache scale loading to the wrapped backbone."""
        self.model.load_kv_cache_scales(quantization_param_path)

    # LRUCacheWorkerLoRAManager instantiation requires model config.
    @property
    def config(self):
        """Expose the wrapped backbone's ``config`` attribute."""
        return self.model.config
...@@ -319,9 +319,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -319,9 +319,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
projector_hidden_act=config.projector_hidden_act) projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model")) hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
......
...@@ -14,13 +14,11 @@ from vllm.attention import AttentionMetadata ...@@ -14,13 +14,11 @@ from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext) InputContext)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.inputs import NestedTensors
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip, from .clip import (CLIPVisionModel, dummy_image_for_clip,
...@@ -286,7 +284,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -286,7 +284,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
vision_feature_layer = config.vision_feature_layer vision_feature_layer = config.vision_feature_layer
...@@ -321,17 +318,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -321,17 +318,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
projector_hidden_act=config.projector_hidden_act) projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model")) hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
# The same model class supports both language generation and embedding )
# because the architecture name is the same
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
...@@ -678,13 +669,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -678,13 +669,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Apply this model's pooler directly to the given hidden states."""
    pooled = self._pooler(hidden_states, pooling_metadata)
    return pooled
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -275,9 +275,10 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -275,9 +275,10 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act) projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model")) hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors) self.language_model.model.make_empty_intermediate_tensors)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment