Unverified Commit 27d7638b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
Co-authored-by: default avatarNickLucche <nlucches@redhat.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
parent 17617398
......@@ -54,8 +54,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.midashenglm import DashengConfig
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
_Tuple2 = Union[int, tuple[int, int], Sequence[int]]
......@@ -744,21 +743,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
return []
return self._process_audio_input(audio_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.decoder.get_input_embeddings(input_ids)
if multimodal_embeddings and len(multimodal_embeddings) > 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.audio_token_id,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
......@@ -771,8 +755,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = None
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.audio_token_id,
)
input_ids = None
return self.decoder.model(input_ids,
......
......@@ -117,6 +117,9 @@ class MiMoMultiTokenPredictor(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......@@ -158,6 +161,9 @@ class MiMoMTP(nn.Module):
self.config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -71,8 +71,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, flatten_bn, isin_list, maybe_prefix
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
......@@ -1144,23 +1143,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
return self._process_multimodal_inputs(modalities)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert len(self.mm_token_ids) > 0
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
list(self.mm_token_ids),
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
......@@ -1178,8 +1160,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=isin_list(input_ids, list(self.mm_token_ids)),
)
input_ids = None
hidden_states = self.llm.model(
......
......@@ -592,10 +592,7 @@ class MiniMaxText01Model(nn.Module):
dtype=torch.long)
minimax_cache_tensors[:, slots_tensor, ...] = 0
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(self,
......@@ -687,10 +684,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
batch_size)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(self,
......
......@@ -28,7 +28,7 @@ from .llava_next import LlavaNextProcessingInfo
from .pixtral import PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
class MiniMaxVL01ImagePixelInputs(TensorSchema):
......@@ -218,22 +218,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
......@@ -403,8 +387,11 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,
......
......@@ -38,8 +38,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
from .vision import get_vision_encoder_info
......@@ -524,22 +523,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
......@@ -592,8 +575,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,
......
......@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.utils import initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors)
MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
......@@ -56,8 +56,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llama4 import Llama4ForCausalLM
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
from .vision import run_dp_sharded_vision_model
......@@ -813,24 +812,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
......@@ -846,8 +827,11 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
# this condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
return self.language_model(input_ids, positions, intermediate_tensors,
......
......@@ -43,6 +43,9 @@ class ModernBertEmbeddings(nn.Module):
eps=config.layer_norm_eps,
bias=config.norm_bias)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.tok_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......@@ -220,6 +223,9 @@ class ModernBertModel(nn.Module):
eps=config.norm_eps,
bias=config.norm_bias)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings.get_input_embeddings(input_ids)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
weights = self.hf_to_vllm_mapper.apply(weights)
......@@ -333,6 +339,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
),
})
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
self_weights = []
......
......@@ -58,7 +58,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
......@@ -819,10 +819,7 @@ class MolmoModel(nn.Module, SupportsQuant):
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -1481,24 +1478,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert self.img_patch_id is not None
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.img_patch_id,
)
return inputs_embeds
def forward(
self,
input_ids: torch.LongTensor,
......@@ -1515,8 +1494,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.img_patch_id,
)
input_ids = None
hidden_states = self.model(input_ids,
......
......@@ -35,8 +35,7 @@ from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
from vllm.model_executor.models.radio import RadioModel
from vllm.model_executor.models.utils import (flatten_bn,
init_vllm_registered_model,
maybe_prefix,
merge_multimodal_embeddings)
isin_list, maybe_prefix)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, MultiModalKwargsItems,
......@@ -1096,8 +1095,8 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
return modalities
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
# Validate the multimodal input keyword arguments
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if modalities is None:
......@@ -1121,30 +1120,6 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0):
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
assert len(context_token_ids) >= 1
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
context_token_ids,
)
return inputs_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
......@@ -1163,9 +1138,17 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=isin_list(input_ids, context_token_ids),
)
input_ids = None
hidden_states = self.language_model(
......
......@@ -38,7 +38,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
IMG_START = '<img>'
IMG_END = '</img>'
......@@ -576,20 +576,24 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
context_token_ids = [self.img_context_token_id]
assert len(context_token_ids) >= 1
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
context_token_ids,
)
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,
......@@ -608,8 +612,11 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.img_context_token_id,
)
input_ids = None
forward_kwargs = {
......
......@@ -295,6 +295,9 @@ class Olmo2Model(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"],
self.config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......@@ -408,6 +411,9 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -48,7 +48,6 @@ from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import merge_multimodal_embeddings
# Cannot find the following number from hf config.
IMAGE_TOKEN = "<image>"
......@@ -501,19 +500,6 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
return image_features
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.image_pad_token_id)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
......@@ -529,8 +515,11 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.image_pad_token_id,
)
input_ids = None
# up until here we have an inputs_embeds 100% numerical identity
......
......@@ -585,17 +585,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
tmp = torch.concat(multimodal_embeddings, dim=0)
inputs_embeds[input_ids == self.image_pad_token_id] = tmp
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
......@@ -612,8 +601,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.image_pad_token_id,
)
input_ids = None
# up until here we have a inputs_embeds 100% numerical identity
......
......@@ -26,8 +26,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
from .vision import get_vision_encoder_info
logger = init_logger(__name__)
......@@ -362,19 +361,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
return inputs_embeds
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
......@@ -388,8 +374,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,
......
......@@ -51,9 +51,9 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, WeightsMapper,
_merge_multimodal_embeddings, flatten_bn,
init_vllm_registered_model, maybe_prefix)
logger = init_logger(__name__)
......@@ -643,14 +643,31 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.image_token_id)
return inputs_embeds
inputs_embeds = self._get_text_embeddings(
input_ids,
self.embed_tokens,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
def forward(self,
input_ids: torch.Tensor,
......@@ -666,8 +683,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
# condition is for v0 compatibility
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=self.image_token_id,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,
......
......@@ -1342,12 +1342,12 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
image_attention_mask)
return image_embeds
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
return []
# The result multimodal_embeddings is tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
......@@ -1371,18 +1371,6 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID])
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
......
......@@ -1151,7 +1151,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
return None
# The result multimodal_embeddings is tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
......@@ -1175,19 +1174,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.embed_tokens(input_ids)
if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID])
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
......
......@@ -50,8 +50,7 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer,
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
try:
......@@ -433,22 +432,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.vision_args.image_token_id,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
......@@ -465,8 +448,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.vision_args.image_token_id,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,
......
......@@ -865,24 +865,26 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_embeddings += audio_embeddings
return multimodal_embeddings
# TODO (ywang96): support overlapping modality embeddings so that
# `use_audio_in_video` will work on V1.
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
# TODO (ywang96): support overlapping modality embeddings so that
# `use_audio_in_video` will work on V1.
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
self.config.image_token_index,
self.config.video_token_index,
self.config.audio_token_index
])
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def get_multimodal_embeddings_v0(
self, **kwargs: object) -> Optional[NestedTensors]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment