"tests/kernels/test_blocksparse_attention.py" did not exist on "fbd80ad4092c4bc48ce672f0435c1d1362aee052"
Unverified Commit 13370712 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Replace embedding models with pooling adapter (#10769)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 7e4bbda5
......@@ -334,7 +334,6 @@ steps:
commands:
- pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
- pytest -v -s models/embedding/language -m core_model
- pytest -v -s models/embedding/vision_language -m core_model
- label: Language Models Test (Extended) # 50min
optional: true
......@@ -346,7 +345,6 @@ steps:
commands:
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model'
- pytest -v -s models/embedding/vision_language -m 'not core_model'
- label: Multi-Modal Models Test (Standard) # 26min
#mirror_hardwares: [amd]
......@@ -359,6 +357,7 @@ steps:
commands:
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model
- pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model
......@@ -376,6 +375,7 @@ steps:
# https://github.com/huggingface/transformers/issues/34307
- pytest -v -s models/decoder_only/vision_language/test_phi3v.py
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/vision_language -m 'not core_model'
- pytest -v -s models/encoder_decoder/language -m 'not core_model'
- pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
......
......@@ -357,7 +357,7 @@ Text Embedding
- ✅︎
* - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
- Qwen2-based
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- :code:`ssmits/Qwen2-7B-Instruct-embed-base` (see note), :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- ✅︎
- ✅︎
* - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
......@@ -378,6 +378,10 @@ Text Embedding
.. tip::
You can override the model's pooling method by passing :code:`--override-pooler-config`.
.. note::
:code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`.
.. note::
Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
......@@ -397,12 +401,21 @@ Reward Modeling
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`LlamaForCausalLM`
- Llama-based
- :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
- ✅︎
- ✅︎
* - :code:`Qwen2ForRewardModel`
- Qwen2-based
- :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
- ✅︎
- ✅︎
.. important::
For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
.. note::
As an interim measure, these models are supported in both offline and online inference via the Embeddings API.
......
......@@ -263,7 +263,6 @@ class HfRunner:
dtype: str = "half",
*,
model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False,
is_sentence_transformer: bool = False,
is_cross_encoder: bool = False,
skip_tokenizer_init: bool = False,
......
......@@ -4,6 +4,8 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
"""
import pytest
from vllm.config import PoolerConfig
from ..utils import check_embeddings_close
......@@ -33,6 +35,9 @@ def test_models(
dtype: str,
) -> None:
vllm_extra_kwargs = {}
if model == "ssmits/Qwen2-7B-Instruct-embed-base":
vllm_extra_kwargs["override_pooler_config"] = \
PoolerConfig(pooling_type="MEAN")
if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}
......
......@@ -6,11 +6,8 @@ import torch.cuda
from vllm.model_executor.models import (is_embedding_model,
is_text_generation_model,
supports_multimodal)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS,
_EMBEDDING_MODELS,
_MULTIMODAL_MODELS,
from vllm.model_executor.models.adapters import as_embedding_model
from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
_SPECULATIVE_DECODING_MODELS,
_TEXT_GENERATION_MODELS,
ModelRegistry)
......@@ -26,18 +23,18 @@ def test_registry_imports(model_arch):
model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)
if model_arch in _SPECULATIVE_DECODING_MODELS:
pass # Ignore these models which do not have a unified format
else:
assert is_text_generation_model(model_cls) is (
model_arch in _TEXT_GENERATION_MODELS
or model_arch in _MULTIMODAL_MODELS)
embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS}
assert is_embedding_model(model_cls) is (model_arch
in embedding_models)
assert supports_multimodal(model_cls) is (model_arch
in _MULTIMODAL_MODELS)
return # Ignore these models which do not have a unified format
if (model_arch in _TEXT_GENERATION_MODELS
or model_arch in _MULTIMODAL_MODELS):
assert is_text_generation_model(model_cls)
# All vLLM models should be convertible to an embedding model
embed_model = as_embedding_model(model_cls)
assert is_embedding_model(embed_model)
if model_arch in _MULTIMODAL_MODELS:
assert supports_multimodal(model_cls)
@fork_new_process_for_each_test
......
from typing import List, Optional, Union
from typing import Iterable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm.attention import AttentionMetadata
from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel
from vllm.sequence import IntermediateTensors
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
class MyGemma2Embedding(Gemma2EmbeddingModel):
class MyGemma2Embedding(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False,
)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
......@@ -18,7 +39,7 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = super().forward(
hidden_states = self.model(
input_ids,
positions,
kv_caches,
......@@ -32,3 +53,17 @@ class MyGemma2Embedding(Gemma2EmbeddingModel):
# Return all-zero embeddings
return torch.zeros_like(hidden_states)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
return self.model.load_weights(weights)
......@@ -26,8 +26,7 @@ def test_auto_task(model_id, expected_task):
@pytest.mark.parametrize(("model_id", "bad_task"), [
("facebook/opt-125m", "embedding"),
("intfloat/e5-mistral-7b-instruct", "generate"),
("Qwen/Qwen2.5-Math-RM-72B", "generate"),
])
def test_incorrect_task(model_id, bad_task):
with pytest.raises(ValueError, match=r"does not support the .* task"):
......
......@@ -370,6 +370,31 @@ class ModelConfig:
selected_task = next(iter(supported_tasks_lst))
if len(supported_tasks) > 1:
suffix_to_preferred_task: List[Tuple[str, _Task]] = [
# Hardcode the models that are exceptions
("AquilaModel", "generate"),
("ChatGLMModel", "generate"),
# Other models follow this pattern
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("EmbeddingModel", "embedding"),
("RewardModel", "embedding"),
("ForSequenceClassification", "embedding"),
]
info, arch = ModelRegistry.inspect_model_cls(architectures)
for suffix, pref_task in suffix_to_preferred_task:
if arch.endswith(suffix) and pref_task in supported_tasks:
selected_task = pref_task
break
else:
if (arch.endswith("Model")
and info.architecture.endswith("ForCausalLM")
and "embedding" in supported_tasks):
selected_task = "embedding"
logger.info(
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task)
......
......@@ -11,8 +11,8 @@ from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
resolve_mm_processor_kwargs)
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
print_warning_once, resolve_mm_processor_kwargs)
from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs
......@@ -136,12 +136,12 @@ class InputRegistry:
"""
def __init__(self) -> None:
self._dummy_factories_by_model_type: Dict[Type[nn.Module],
DummyDataFactory] = {}
self._dummy_encoder_factories_by_model_type: Dict[
Type[nn.Module], DummyDataFactory] = {}
self._input_processors_by_model_type: Dict[Type[nn.Module],
InputProcessor] = {}
self._dummy_factories_by_model_type = \
ClassRegistry[nn.Module, DummyDataFactory]()
self._dummy_encoder_factories_by_model_type = \
ClassRegistry[nn.Module, DummyDataFactory]()
self._input_processors_by_model_type = \
ClassRegistry[nn.Module, InputProcessor]()
def _default_dummy_data_factory(
self,
......
......@@ -60,9 +60,7 @@ class Pooler(nn.Module):
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
) -> Optional["Pooler"]:
if pooler_config is None:
return None
) -> "Pooler":
return cls(
pooling_type=PoolingType[pooler_config.pooling_type]
if pooler_config.pooling_type is not None else pooling_type,
......
......@@ -9,6 +9,7 @@ import itertools
import json
import math
import os
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
......@@ -97,22 +98,31 @@ def device_loading_context(module: torch.nn.Module,
logger = init_logger(__name__)
def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
def _initialize_model(
vllm_config: VllmConfig,
*,
prefix: str = "",
architectures: Optional[list[str]] = None,
) -> nn.Module:
"""Initialize a model with the given configurations."""
model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config)
model_class, _ = get_model_architecture(model_config,
architectures=architectures)
signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
with set_current_vllm_config(vllm_config):
return model_class(vllm_config=vllm_config, prefix=prefix)
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
"for the design and update the model class accordingly.")
logger.warning(msg)
warnings.warn(msg, DeprecationWarning, stacklevel=2)
logger.warning(
"Trying to guess the arguments for old-style model class %s",
model_class,
......@@ -356,7 +366,7 @@ class DefaultModelLoader(BaseModelLoader):
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(
self._get_all_weights(model_config, model))
# We only enable strict check for non-quantiized models
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
......
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type
from typing import Optional, Tuple, Type
import torch
from torch import nn
from vllm.config import ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import as_embedding_model
@contextlib.contextmanager
......@@ -19,8 +20,13 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
model_config: ModelConfig,
*,
architectures: Optional[list[str]] = None,
) -> Tuple[Type[nn.Module], str]:
if architectures is None:
architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = [
......@@ -32,7 +38,11 @@ def get_model_architecture(
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
return ModelRegistry.resolve_model_cls(architectures)
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embedding":
model_cls = as_embedding_model(model_cls)
return model_cls, arch
def get_architecture_class_name(model_config: ModelConfig) -> str:
......
from collections.abc import Iterable
from typing import Any, TypeVar
import torch
import torch.nn as nn
from .interfaces_base import VllmModelForEmbedding, is_embedding_model
_T = TypeVar("_T", bound=type[nn.Module])
def as_embedding_model(cls: _T) -> _T:
    """
    Return an embedding-capable subclass of the vLLM model class ``cls``.

    A class that :func:`is_embedding_model` already accepts is handed back
    untouched. Any other class is subclassed so that it also derives from
    ``VllmModelForEmbedding``, drops its ``lm_head`` and
    ``logits_processor`` attributes after construction, and exposes a
    ``pooler`` method backed by ``self._pooler``.

    The generated class is named after ``cls`` with a trailing generation
    suffix stripped and ``ForEmbedding`` appended.
    """
    if is_embedding_model(cls):
        # Existing embedding models are returned as-is, never re-wrapped
        return cls

    # Imported lazily, inside the function rather than at module level
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput,
                                                   PoolingType)
    from vllm.model_executor.pooling_metadata import PoolingMetadata

    from .utils import AutoWeightsLoader, WeightsMapper

    class ModelForEmbedding(cls, VllmModelForEmbedding):

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # These are not used in embedding models
            unused_attrs = ("lm_head", "logits_processor")
            for unused in unused_attrs:
                if not hasattr(self, unused):
                    continue
                delattr(self, unused)

            config = vllm_config.model_config.pooler_config
            assert config is not None

            # A pooler that the wrapped model already created is kept
            existing_pooler = getattr(self, "_pooler", None)
            if not existing_pooler:
                self._pooler = Pooler.from_config_with_defaults(
                    config,
                    pooling_type=PoolingType.LAST,
                    normalize=True,
                    softmax=False,
                )

        def pooler(
            self,
            hidden_states: torch.Tensor,
            pooling_metadata: PoolingMetadata,
        ) -> PoolerOutput:
            """Pool ``hidden_states`` by delegating to ``self._pooler``."""
            return self._pooler(hidden_states, pooling_metadata)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            """Load ``weights``, skipping every entry under ``lm_head.``."""
            # TODO: Support uninitialized params tracking

            # `lm_head` was deleted in `__init__`, so don't load it
            weights = ((key, tensor) for key, tensor in weights
                       if not key.startswith("lm_head."))

            # If `*ForCausalLM` defines `load_weights` on the inner model
            # and no other inner module owns parameters, checkpoints from
            # both `*Model` and `*ForCausalLM` can be loaded
            inner = getattr(self, "model", None)
            if hasattr(inner, "load_weights"):
                # True when a child other than `model` holds a parameter
                has_other_params = any(
                    child_name != "model"
                    and next(child.parameters(), None) is not None
                    for child_name, child in self.named_children())

                if not has_other_params:
                    prefix_mapper = WeightsMapper(
                        orig_to_new_prefix={"model.": ""})
                    inner.load_weights(prefix_mapper.apply(weights))
                    return

            if not hasattr(cls, "load_weights"):
                # Fallback
                AutoWeightsLoader(self).load_weights(weights)
                return

            # For most other models
            cls.load_weights(self, weights)  # type: ignore

    # Strip the generation suffixes one after another, in this order
    embed_name = cls.__name__
    for suffix in ("ForCausalLM", "ForConditionalGeneration", "ChatModel",
                   "LMHeadModel"):
        embed_name = embed_name.removesuffix(suffix)
    ModelForEmbedding.__name__ = embed_name + "ForEmbedding"

    return ModelForEmbedding  # type: ignore
......@@ -512,9 +512,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
......
......@@ -30,19 +30,17 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -455,55 +453,3 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
    """
    Gemma2 backbone combined with a pooler for embedding use.

    The wrapped `Gemma2Model` performs the forward pass; the `Pooler`
    is applied to the resulting hidden states in `pooler`.

    Attributes:
        model: An instance of Gemma2Model used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        self.model = Gemma2Model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        # LAST pooling with normalization and no softmax are the values
        # handed to `from_config_with_defaults` alongside the pooler config
        self._pooler = Pooler.from_config_with_defaults(
            vllm_config.model_config.pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=True,
            softmax=False,
        )
        # Re-export the inner model's factory on this wrapper
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the inner model and return its output unchanged."""
        outputs = self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds,
        )
        return outputs

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool `hidden_states` by delegating to `self._pooler`."""
        pooled = self._pooler(hidden_states, pooling_metadata)
        return pooled

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load `weights` into the inner model, skipping `lm_head.*`."""
        # Remap the "model." prefix to the empty prefix
        prefix_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
        remapped = prefix_mapper.apply(weights)
        # This class has no `lm_head`, so those entries are filtered out
        kept = ((key, tensor) for key, tensor in remapped
                if not key.startswith("lm_head."))
        self.model.load_weights(kept)
......@@ -474,9 +474,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.mlp1 = self._init_mlp1(config)
......
......@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
......@@ -47,14 +46,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
extract_layer_index, is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -511,11 +509,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.config = config
self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config, prefix=prefix)
self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
......@@ -544,13 +543,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.STEP,
normalize=False,
softmax=False)
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return LlamaModel(vllm_config=vllm_config, prefix=prefix)
......@@ -581,14 +576,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
logits = self.compute_logits(hidden_states, None)
return self._pooler(logits, pooling_metadata)
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
......@@ -639,78 +626,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
name = name.replace(item, mapping[item])
return name, loaded_weight
class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
    """
    Llama backbone combined with a pooler for embedding use.

    The wrapped `LlamaModel` performs the forward pass; the `Pooler`
    is applied to the resulting hidden states in `pooler`.

    Attributes:
        model: An instance of LlamaModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
    }
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        self.model = LlamaModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        # LAST pooling with normalization and no softmax are the values
        # handed to `from_config_with_defaults` alongside the pooler config
        self._pooler = Pooler.from_config_with_defaults(
            vllm_config.model_config.pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=True,
            softmax=False,
        )
        # Re-export the inner model's factory on this wrapper
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the inner model and return its output unchanged."""
        outputs = self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds,
        )
        return outputs

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool `hidden_states` by delegating to `self._pooler`."""
        pooled = self._pooler(hidden_states, pooling_metadata)
        return pooled

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load `weights` into the inner model, skipping `lm_head.*`."""
        # Remap the "model." prefix to the empty prefix
        prefix_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
        remapped = prefix_mapper.apply(weights)
        # This class has no `lm_head`, so those entries are filtered out
        kept = ((key, tensor) for key, tensor in remapped
                if not key.startswith("lm_head."))
        self.model.load_weights(kept)

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Forward KV-cache scale loading to the inner model."""
        self.model.load_kv_cache_scales(quantization_param_path)

    # LRUCacheWorkerLoRAManager instantiation requires model config.
    @property
    def config(self):
        """Config object of the inner model."""
        return self.model.config
......@@ -319,9 +319,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
......
......@@ -14,13 +14,11 @@ from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip,
......@@ -286,7 +284,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config
vision_feature_layer = config.vision_feature_layer
......@@ -321,17 +318,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
......@@ -678,13 +669,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
......
......@@ -275,9 +275,10 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment