Unverified Commit 27d7638b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
Co-authored-by: default avatarNickLucche <nlucches@redhat.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
parent 17617398
...@@ -132,6 +132,9 @@ class Glm4MoeMultiTokenPredictor(nn.Module): ...@@ -132,6 +132,9 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -173,6 +176,9 @@ class Glm4MoeMTP(nn.Module, SupportsPP): ...@@ -173,6 +176,9 @@ class Glm4MoeMTP(nn.Module, SupportsPP):
prefix=maybe_prefix( prefix=maybe_prefix(
prefix, "model")) prefix, "model"))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .chatglm import ChatGLMBaseModel, ChatGLMModel from .chatglm import ChatGLMBaseModel, ChatGLMModel
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
from .utils import flatten_bn, merge_multimodal_embeddings from .utils import flatten_bn, isin_list
class GLMVImagePixelInputs(TensorSchema): class GLMVImagePixelInputs(TensorSchema):
...@@ -607,28 +607,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -607,28 +607,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.transformer.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=[
self.config.boi_token_id,
self.config.pad_token_id,
self.config.eoi_token_id,
],
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -644,8 +622,15 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -644,8 +622,15 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=isin_list(input_ids, [
self.config.boi_token_id,
self.config.pad_token_id,
self.config.eoi_token_id,
]),
)
input_ids = None input_ids = None
hidden_states = self.transformer(input_ids, positions, hidden_states = self.transformer(input_ids, positions,
......
...@@ -52,8 +52,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -52,8 +52,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .blip2 import Blip2QFormerModel from .blip2 import Blip2QFormerModel
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, embed_multimodal, from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
init_vllm_registered_model, maybe_prefix)
### Audio Input ### Audio Input
...@@ -720,6 +719,9 @@ class GraniteSpeechForConditionalGeneration( ...@@ -720,6 +719,9 @@ class GraniteSpeechForConditionalGeneration(
# Split variable length features into a tuple # Split variable length features into a tuple
return torch.split(masked_embeds, audio_input["audio_embed_sizes"]) return torch.split(masked_embeds, audio_input["audio_embed_sizes"])
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, self,
**kwargs: object, **kwargs: object,
...@@ -728,7 +730,7 @@ class GraniteSpeechForConditionalGeneration( ...@@ -728,7 +730,7 @@ class GraniteSpeechForConditionalGeneration(
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None: if audio_input is None:
return [] return []
return None
audio_features = self._process_audio_input(audio_input) audio_features = self._process_audio_input(audio_input)
return audio_features return audio_features
...@@ -736,19 +738,21 @@ class GraniteSpeechForConditionalGeneration( ...@@ -736,19 +738,21 @@ class GraniteSpeechForConditionalGeneration(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
"""Compute the merged LLM / audio embeddings.""" # This is to satisfy the type checker for each overload
if multimodal_embeddings is None \ if multimodal_embeddings is None or is_multimodal is None:
or len(multimodal_embeddings) == 0: return super().get_input_embeddings(input_ids)
return self.language_model.get_input_embeddings(input_ids)
inputs_embeds = embed_multimodal( return super().get_input_embeddings(
input_ids, input_ids,
self.config.audio_token_index, multimodal_embeddings=multimodal_embeddings,
self.language_model.model.get_input_embeddings, is_multimodal=is_multimodal,
multimodal_embeddings, handle_oov_mm_token=handle_oov_mm_token,
) )
return inputs_embeds
def forward( def forward(
self, self,
...@@ -765,7 +769,11 @@ class GraniteSpeechForConditionalGeneration( ...@@ -765,7 +769,11 @@ class GraniteSpeechForConditionalGeneration(
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
audio_embeds = self.get_multimodal_embeddings(**kwargs) audio_embeds = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds) inputs_embeds = self.get_input_embeddings(
input_ids,
audio_embeds,
is_multimodal=input_ids == self.config.audio_token_index,
)
input_ids = None input_ids = None
model_output = self.language_model(input_ids, positions, model_output = self.language_model(input_ids, positions,
......
...@@ -989,6 +989,9 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -989,6 +989,9 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
moe.n_redundant_experts = self.num_redundant_experts moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map() moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -45,8 +45,8 @@ from vllm.sequence import IntermediateTensors ...@@ -45,8 +45,8 @@ from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model, isin_list,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix)
from .vision import get_vision_encoder_info from .vision import get_vision_encoder_info
EOT = "<|endofturn|>" EOT = "<|endofturn|>"
...@@ -691,7 +691,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -691,7 +691,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, self,
**kwargs: Unpack[HCXVisionMultimodalInputs], **kwargs: Unpack[HCXVisionMultimodalInputs],
) -> Optional[MultiModalEmbeddings]: ) -> MultiModalEmbeddings:
multimodal_embeddings = list() multimodal_embeddings = list()
if kwargs.get("pixel_values_images") is not None: if kwargs.get("pixel_values_images") is not None:
...@@ -736,26 +736,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -736,26 +736,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings.append(_multimodal_embeddings_videos) multimodal_embeddings.append(_multimodal_embeddings_videos)
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
placeholder_token_id=[
self.config.image_token_id,
self.config.video_token_id,
],
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -771,8 +751,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -771,8 +751,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
multimodal_embeddings) input_ids,
multimodal_embeddings,
is_multimodal=isin_list(
input_ids,
[self.config.image_token_id, self.config.video_token_id]),
)
input_ids = None input_ids = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
......
...@@ -52,8 +52,7 @@ from .idefics2_vision_model import ( ...@@ -52,8 +52,7 @@ from .idefics2_vision_model import (
# yapf: enable # yapf: enable
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
merge_multimodal_embeddings)
class Idefics3ImagePixelInputs(TensorSchema): class Idefics3ImagePixelInputs(TensorSchema):
...@@ -539,10 +538,7 @@ class Idefics3Model(nn.Module): ...@@ -539,10 +538,7 @@ class Idefics3Model(nn.Module):
return image_hidden_states return image_hidden_states
def get_input_embeddings( def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.text_model.get_input_embeddings(input_ids) return self.text_model.get_input_embeddings(input_ids)
def forward( def forward(
...@@ -695,22 +691,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -695,22 +691,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return self._process_image_input(image_input) return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_id,
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -726,8 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -726,8 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None input_ids = None
hidden_states = self.model.text_model(input_ids, hidden_states = self.model.text_model(input_ids,
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, MutableSequence from collections.abc import Iterable, Mapping, MutableSequence
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, from typing import (TYPE_CHECKING, Callable, ClassVar, Literal, Optional,
Union, overload, runtime_checkable) Protocol, Union, overload, runtime_checkable)
import numpy as np import numpy as np
import torch import torch
...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.utils import supports_kw from vllm.utils import supports_kw
from .interfaces_base import is_pooling_model from .interfaces_base import VllmModel, is_pooling_model
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -90,7 +90,7 @@ class SupportsMultiModal(Protocol): ...@@ -90,7 +90,7 @@ class SupportsMultiModal(Protocol):
""" """
... ...
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> VllmModel:
""" """
Returns the underlying language model used for text generation. Returns the underlying language model used for text generation.
...@@ -102,17 +102,84 @@ class SupportsMultiModal(Protocol): ...@@ -102,17 +102,84 @@ class SupportsMultiModal(Protocol):
""" """
... ...
@overload
def get_input_embeddings(self, input_ids: Tensor) -> Tensor:
...
@overload
def get_input_embeddings(
self,
input_ids: Tensor,
multimodal_embeddings: MultiModalEmbeddings,
*,
is_multimodal: torch.Tensor,
handle_oov_mm_token: bool = False,
) -> Tensor:
...
def _get_text_embeddings(
self,
input_ids: Tensor,
get_input_embeddings: Callable[[Tensor], Tensor],
*,
is_multimodal: Optional[Tensor],
handle_oov_mm_token: bool,
) -> Tensor:
if handle_oov_mm_token and is_multimodal is not None:
is_text = ~is_multimodal
text_embeds = get_input_embeddings(input_ids[is_text])
return torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
return get_input_embeddings(input_ids)
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: Tensor, input_ids: Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[Tensor] = None,
handle_oov_mm_token: bool = False,
) -> Tensor: ) -> Tensor:
""" """
Returns the input embeddings merged from the text embeddings from Apply token embeddings to `input_ids`.
input_ids and the multimodal embeddings generated from multimodal
kwargs. If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=False`
to avoid calling the language model's `get_input_embeddings` method
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
""" """
... from .utils import _merge_multimodal_embeddings
inputs_embeds = self._get_text_embeddings(
input_ids,
self.get_language_model().get_input_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
@runtime_checkable @runtime_checkable
......
...@@ -41,6 +41,13 @@ class VllmModel(Protocol[T_co]): ...@@ -41,6 +41,13 @@ class VllmModel(Protocol[T_co]):
) -> None: ) -> None:
... ...
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
"""Apply token embeddings to `input_ids`."""
...
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -54,6 +61,19 @@ def _check_vllm_model_init(model: Union[type[object], object]) -> bool: ...@@ -54,6 +61,19 @@ def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
return supports_kw(model_init, "vllm_config") return supports_kw(model_init, "vllm_config")
def _check_vllm_model_get_input_embeddings(
model: Union[type[object], object]) -> bool:
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
if not callable(model_get_input_embeddings):
logger.warning(
"The model (%s) is missing the `get_input_embeddings` method.",
model,
)
return False
return True
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool: def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
model_forward = getattr(model, "forward", None) model_forward = getattr(model, "forward", None)
if not callable(model_forward): if not callable(model_forward):
...@@ -88,7 +108,9 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]: ...@@ -88,7 +108,9 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]:
def is_vllm_model( def is_vllm_model(
model: Union[type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]: ) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
return _check_vllm_model_init(model) and _check_vllm_model_forward(model) return (_check_vllm_model_init(model)
and _check_vllm_model_get_input_embeddings(model)
and _check_vllm_model_forward(model))
@runtime_checkable @runtime_checkable
......
...@@ -40,8 +40,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -40,8 +40,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, isin_list, maybe_prefix)
merge_multimodal_embeddings)
class InternS1MultiModalProjector(nn.Module): class InternS1MultiModalProjector(nn.Module):
...@@ -767,24 +766,24 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -767,24 +766,24 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None and len(
if multimodal_embeddings is not None \ multimodal_embeddings) > 0:
and len(multimodal_embeddings) != 0:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
assert len(context_token_ids) >= 1
self._set_visual_token_mask(input_ids) self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids, input_ids,
inputs_embeds, multimodal_embeddings=multimodal_embeddings,
multimodal_embeddings, is_multimodal=is_multimodal,
context_token_ids, handle_oov_mm_token=handle_oov_mm_token,
) )
return inputs_embeds
def forward( def forward(
self, self,
...@@ -802,9 +801,17 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -802,9 +801,17 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=isin_list(input_ids, context_token_ids),
)
input_ids = None input_ids = None
forward_kwargs = { forward_kwargs = {
......
...@@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) isin_list, maybe_prefix)
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
...@@ -1339,24 +1339,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -1339,24 +1339,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None and len(
if multimodal_embeddings is not None \ multimodal_embeddings) > 0:
and len(multimodal_embeddings) != 0:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
assert len(context_token_ids) >= 1
self._set_visual_token_mask(input_ids) self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids, input_ids,
inputs_embeds, multimodal_embeddings=multimodal_embeddings,
multimodal_embeddings, is_multimodal=is_multimodal,
context_token_ids, handle_oov_mm_token=handle_oov_mm_token,
) )
return inputs_embeds
def forward( def forward(
self, self,
...@@ -1374,9 +1374,17 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -1374,9 +1374,17 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=isin_list(input_ids, context_token_ids),
)
input_ids = None input_ids = None
forward_kwargs = { forward_kwargs = {
......
...@@ -1450,24 +1450,6 @@ class BaseKeyeModule(nn.Module): ...@@ -1450,24 +1450,6 @@ class BaseKeyeModule(nn.Module):
multimodal_embeddings += video_embeddings multimodal_embeddings += video_embeddings
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
[
self.config.image_token_id,
self.config.video_token_id,
],
)
return inputs_embeds
def get_input_embeddings_v0( def get_input_embeddings_v0(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -66,7 +66,6 @@ from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model ...@@ -66,7 +66,6 @@ from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
from vllm.model_executor.models.interfaces import (SupportsMultiModal, from vllm.model_executor.models.interfaces import (SupportsMultiModal,
SupportsPP) SupportsPP)
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors) MultiModalKwargsItems, NestedTensors)
...@@ -424,26 +423,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -424,26 +423,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
# `get_input_embeddings` should already be implemented for the language
# model as one of the requirements of basic vLLM model implementation.
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.media_placeholder_token_id)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -462,14 +441,12 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -462,14 +441,12 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
if image_input is None: if image_input is None:
inputs_embeds = None inputs_embeds = None
else: else:
inputs_embeds = self.get_input_embeddings(input_ids)
image_embeds = self._process_image_input(image_input) image_embeds = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = self.get_input_embeddings(
input_ids, input_ids,
inputs_embeds,
image_embeds, image_embeds,
placeholder_token_id=self.config. is_multimodal=input_ids ==
media_placeholder_token_id, self.config.media_placeholder_token_id,
) )
input_ids = None input_ids = None
......
...@@ -522,6 +522,9 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -522,6 +522,9 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -37,9 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -37,9 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer, from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
Llama4ForCausalLM) Llama4ForCausalLM)
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.inputs import NestedTensors
from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings from .interfaces import SupportsMultiModal
from .utils import AutoWeightsLoader, maybe_prefix
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -79,10 +79,7 @@ class LlamaModel(nn.Module): ...@@ -79,10 +79,7 @@ class LlamaModel(nn.Module):
self.norm = RMSNorm(self.config.hidden_size, self.norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps) eps=self.config.rms_norm_eps)
def get_input_embeddings( def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -194,6 +191,11 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): ...@@ -194,6 +191,11 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
self.logits_processor = LogitsProcessor(self.config.vocab_size, self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale) scale=logit_scale)
def get_language_model(self) -> torch.nn.Module:
return self.model
get_input_embeddings = SupportsMultiModal.get_input_embeddings # type: ignore
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -220,20 +222,3 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): ...@@ -220,20 +222,3 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
skip_prefixes=(["lm_head."]), skip_prefixes=(["lm_head."]),
) )
loader.load_weights(map(transform, weights)) loader.load_weights(map(transform, weights))
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
...@@ -73,6 +73,9 @@ class LlamaModel(nn.Module): ...@@ -73,6 +73,9 @@ class LlamaModel(nn.Module):
self.config.hidden_size, self.config.hidden_size,
bias=False) bias=False)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -149,6 +152,9 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): ...@@ -149,6 +152,9 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
self.logits_processor = LogitsProcessor(self.config.vocab_size, self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale) scale=logit_scale)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -18,7 +18,6 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -18,7 +18,6 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.llama import (LlamaDecoderLayer, from vllm.model_executor.models.llama import (LlamaDecoderLayer,
LlamaForCausalLM) LlamaForCausalLM)
...@@ -144,10 +143,7 @@ class LlamaModel(nn.Module): ...@@ -144,10 +143,7 @@ class LlamaModel(nn.Module):
eps=self.config.rms_norm_eps, eps=self.config.rms_norm_eps,
) )
def get_input_embeddings( def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -239,6 +235,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): ...@@ -239,6 +235,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
requires_grad=False, requires_grad=False,
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -302,11 +301,3 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): ...@@ -302,11 +301,3 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
skip_substrs=skip_substrs, skip_substrs=skip_substrs,
) )
loader.load_weights(model_weights.items()) loader.load_weights(model_weights.items())
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
return inputs_embeds
...@@ -41,8 +41,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP ...@@ -41,8 +41,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix)
merge_multimodal_embeddings)
from .vision import get_vision_encoder_info from .vision import get_vision_encoder_info
...@@ -676,22 +675,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -676,22 +675,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self._process_image_input(image_input) return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -744,8 +727,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -744,8 +727,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None input_ids = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
......
...@@ -25,8 +25,8 @@ from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo, ...@@ -25,8 +25,8 @@ from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
LlavaDummyInputsBuilder, LlavaLikeConfig, LlavaDummyInputsBuilder, LlavaLikeConfig,
LlavaMultiModalProjector, init_vision_tower_for_llava) LlavaMultiModalProjector, init_vision_tower_for_llava)
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
flatten_bn, init_vllm_registered_model, maybe_prefix) init_vllm_registered_model, maybe_prefix)
class LlavaNextImagePixelInputs(TensorSchema): class LlavaNextImagePixelInputs(TensorSchema):
...@@ -474,19 +474,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -474,19 +474,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
if multimodal_embeddings is None \ return super().get_input_embeddings(
or len(multimodal_embeddings) == 0:
return self.language_model.get_input_embeddings(input_ids)
inputs_embeds = embed_multimodal(
input_ids, input_ids,
self.config.image_token_index, multimodal_embeddings=multimodal_embeddings,
self.language_model.model.get_input_embeddings, is_multimodal=is_multimodal,
multimodal_embeddings, handle_oov_mm_token=handle_oov_mm_token,
) )
return inputs_embeds
def forward( def forward(
self, self,
...@@ -549,8 +551,11 @@ model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens]. ...@@ -549,8 +551,11 @@ model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens].
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None input_ids = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
......
...@@ -30,8 +30,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP ...@@ -30,8 +30,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava from .llava import init_vision_tower_for_llava
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix)
merge_multimodal_embeddings)
from .vision import get_vision_encoder_info from .vision import get_vision_encoder_info
...@@ -415,19 +414,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -415,19 +414,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self._process_video_pixels(video_input) vision_embeddings = self._process_video_pixels(video_input)
return vision_embeddings return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.video_token_index)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -449,8 +435,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -449,8 +435,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.video_token_index,
)
input_ids = None input_ids = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
......
...@@ -850,19 +850,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -850,19 +850,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_index, self.config.video_token_index])
return inputs_embeds
def get_input_embeddings_v0( def get_input_embeddings_v0(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment