Unverified Commit 5089fd74 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[V0 Deprecation] Remove V0 logic from `get_input_embeddings` interface (#25242)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent a3d087ad
...@@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors ...@@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vision_encoder_info from .vision import get_vision_encoder_info
EOT = "<|endofturn|>" EOT = "<|endofturn|>"
...@@ -740,33 +741,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -740,33 +741,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if (kwargs.get("pixel_values_images") is not None if multimodal_embeddings is not None \
or kwargs.get("pixel_values_videos") and len(multimodal_embeddings) != 0:
is not None): # v0 compatibility inputs_embeds = merge_multimodal_embeddings(
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) input_ids,
if multimodal_embeddings is not None: inputs_embeds,
multimodal_embeddings = torch.cat(multimodal_embeddings, dim=0) multimodal_embeddings,
_mask_image = input_ids == self.config.image_token_id placeholder_token_id=[
_mask_video = input_ids == self.config.video_token_id self.config.image_token_id,
assert _mask_image.sum() + _mask_video.sum() == len( self.config.video_token_id,
multimodal_embeddings) ],
)
if multimodal_embeddings.dtype != inputs_embeds.dtype:
multimodal_embeddings = multimodal_embeddings.to(
dtype=inputs_embeds.dtype)
if multimodal_embeddings.device != inputs_embeds.device:
multimodal_embeddings = multimodal_embeddings.to(
device=inputs_embeds.device)
if _mask_image.sum() > 0:
inputs_embeds[
_mask_image] = multimodal_embeddings[:sum(_mask_image)]
if _mask_video.sum() > 0:
inputs_embeds[_mask_video] = multimodal_embeddings[
-sum(_mask_video):]
return inputs_embeds return inputs_embeds
def forward( def forward(
...@@ -783,8 +771,9 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -783,8 +771,9 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids=input_ids, multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
**kwargs) inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
input_ids = None input_ids = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
......
...@@ -23,7 +23,6 @@ from vllm.utils import supports_kw ...@@ -23,7 +23,6 @@ from vllm.utils import supports_kw
from .interfaces_base import is_pooling_model from .interfaces_base import is_pooling_model
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -97,33 +96,10 @@ class SupportsMultiModal(Protocol): ...@@ -97,33 +96,10 @@ class SupportsMultiModal(Protocol):
""" """
... ...
# Only for models that support v0 chunked prefill
# TODO(ywang96): Remove this overload once v0 is deprecated
@overload
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: Tensor, input_ids: Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
attn_metadata: Optional["AttentionMetadata"] = None,
) -> Tensor:
...
# TODO: Remove this overload once v0 is deprecated
@overload
def get_input_embeddings(
self,
input_ids: Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> Tensor:
...
def get_input_embeddings(
self,
input_ids: Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
# Only necessary so that the v0 overload is valid
# TODO: Remove attn_metadata once v0 is deprecated
attn_metadata: Optional["AttentionMetadata"] = None,
) -> Tensor: ) -> Tensor:
""" """
Returns the input embeddings merged from the text embeddings from Returns the input embeddings merged from the text embeddings from
......
...@@ -13,9 +13,7 @@ from transformers import BatchFeature, ProcessorMixin ...@@ -13,9 +13,7 @@ from transformers import BatchFeature, ProcessorMixin
from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.model_loader import DefaultModelLoader
...@@ -37,8 +35,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, ...@@ -37,8 +35,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings, merge_multimodal_embeddings)
merge_multimodal_embeddings_from_map)
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>" _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
_MAX_ENCODER_BATCH_SIZE = 16 _MAX_ENCODER_BATCH_SIZE = 16
...@@ -568,17 +565,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -568,17 +565,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
safe_input_ids) safe_input_ids)
if multimodal_embeddings is not None and len( if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0: multimodal_embeddings) > 0:
inputs_embeds = merge_multimodal_embeddings(
# TODO(ywang96): remove this block after v0 is deprecated. input_ids, inputs_embeds, multimodal_embeddings,
if not envs.VLLM_USE_V1: self.config.audio_token_index)
attn_metadata = get_forward_context().attn_metadata
merge_multimodal_embeddings_from_map(
inputs_embeds, multimodal_embeddings,
attn_metadata.multi_modal_placeholder_index_maps["audio"])
else:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.audio_token_index)
return inputs_embeds return inputs_embeds
def forward(self, def forward(self,
......
...@@ -15,7 +15,7 @@ import vllm.envs as envs ...@@ -15,7 +15,7 @@ import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available, from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
is_uva_available) is_uva_available)
...@@ -389,22 +389,6 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: ...@@ -389,22 +389,6 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
_embedding_count_expression(inner) for inner in embeddings) _embedding_count_expression(inner) for inner in embeddings)
def merge_multimodal_embeddings_from_map(
inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
placeholder map .
Note:
This updates ``inputs_embeds`` in place.
"""
flattened_embeddings = _flatten_embeddings(multimodal_embeddings)
inputs_embeds[placeholder_map.dest] = flattened_embeddings[
placeholder_map.src].to(dtype=inputs_embeds.dtype)
return inputs_embeds
def _merge_multimodal_embeddings( def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor, is_multimodal: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment