Unverified Commit 27d7638b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
Co-authored-by: default avatarNickLucche <nlucches@redhat.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
parent 17617398
......@@ -66,35 +66,12 @@ Further update the model as follows:
!!! important
The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][] objects**, each of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g., image) of the request.
- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.
??? code
```python
from .utils import merge_multimodal_embeddings
class YourModelForImage2Seq(nn.Module):
...
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
# `get_input_embeddings` should already be implemented for the language
# model as one of the requirements of basic vLLM model implementation.
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.image_token_index)
!!! note
By default, vLLM merges the multimodal embeddings into the text embeddings at the locations given by the
[PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] produced during input processing.
This logic can be found in [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
return inputs_embeds
```
You may override this method if additional logic is required for your model when merging embeddings.
- Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model.
......
......@@ -509,9 +509,14 @@ class ModelConfig:
else: # task == "auto"
pass
else:
debug_info = {
"architectures": architectures,
"is_generative_model": is_generative_model,
"is_pooling_model": is_pooling_model,
}
raise AssertionError("The model should be a generative or "
"pooling model when task is set to "
f"{self.task!r}.")
f"{self.task!r}. Found: {debug_info}")
self.runner = runner
self.convert = convert
......
......@@ -38,8 +38,7 @@ from .idefics2_vision_model import (
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter, maybe_prefix,
merge_multimodal_embeddings)
is_pp_missing_parameter, maybe_prefix)
class AriaImagePixelInputs(TensorSchema):
......@@ -605,19 +604,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
multimodal_embeddings = self._process_image_input(image_input)
return multimodal_embeddings
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in any multimodal embeddings.

    The text embeddings come from ``self.language_model``. When
    ``multimodal_embeddings`` is ``None`` or empty they are returned as-is;
    otherwise ``merge_multimodal_embeddings`` is called with
    ``self.config.image_token_index`` as the placeholder token ID.
    """
    text_embeds = self.language_model.get_input_embeddings(input_ids)

    # Nothing to merge: hand back the plain text embeddings.
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return text_embeds

    placeholder_id = self.config.image_token_index
    return merge_multimodal_embeddings(input_ids, text_embeds,
                                       multimodal_embeddings, placeholder_id)
def forward(
self,
input_ids: torch.Tensor,
......@@ -628,10 +614,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model(
......
......@@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
class AyaVisionImagePixelInputs(TensorSchema):
......@@ -417,23 +416,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return self._process_image_input(image_input, **kwargs)
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in any multimodal embeddings.

    The text embeddings come from ``self.language_model``. When
    ``multimodal_embeddings`` is ``None`` or empty they are returned as-is;
    otherwise ``merge_multimodal_embeddings`` is called with
    ``self.config.image_token_index`` as ``placeholder_token_id``.
    """
    text_embeds = self.language_model.get_input_embeddings(input_ids)

    # Nothing to merge: hand back the plain text embeddings.
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return text_embeds

    placeholder_id = self.config.image_token_index
    return merge_multimodal_embeddings(
        input_ids=input_ids,
        inputs_embeds=text_embeds,
        multimodal_embeddings=multimodal_embeddings,
        placeholder_token_id=placeholder_id,
    )
def forward(
self,
input_ids: torch.Tensor,
......@@ -449,8 +431,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(
......
......@@ -348,6 +348,9 @@ class BertModel(nn.Module, SupportsQuant):
self.encoder = BertEncoder(vllm_config=vllm_config,
prefix=f"{prefix}.encoder")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of calling ``self.embeddings`` on ``input_ids``."""
    embed = self.embeddings
    return embed(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......@@ -457,6 +460,9 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
prefix=maybe_prefix(prefix, "model"))
self.pooler = self._build_pooler(pooler_config)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Forward the lookup to ``self.model.get_input_embeddings``."""
    inner = self.model
    return inner.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......@@ -588,6 +594,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
),
})
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Forward the lookup to ``self.bert.get_input_embeddings``."""
    backbone = self.bert
    return backbone.get_input_embeddings(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights)
......@@ -637,6 +646,9 @@ class BertForTokenClassification(nn.Module):
Pooler.for_encode(pooler_config),
})
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Forward the lookup to ``self.bert.get_input_embeddings``."""
    backbone = self.bert
    return backbone.get_input_embeddings(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights)
......
......@@ -426,6 +426,9 @@ class BertWithRope(nn.Module, SupportsQuant):
prefix=f"{prefix}.encoder")
self.pooler = BertPooler(self.config) if add_pooling_layer else None
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of calling ``self.embeddings`` on ``input_ids``."""
    embed = self.embeddings
    return embed(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......@@ -673,6 +676,9 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
loaded_params = loader.load_weights(weights)
return loaded_params
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Forward the lookup to ``self.new.get_input_embeddings``."""
    backbone = self.new
    return backbone.get_input_embeddings(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
......
......@@ -27,7 +27,7 @@ from .blip import BlipVisionModel
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
......@@ -631,19 +631,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in any multimodal embeddings.

    The text embeddings come from ``self.language_model``. When
    ``multimodal_embeddings`` is ``None`` or empty they are returned as-is;
    otherwise ``merge_multimodal_embeddings`` is called with the
    module-level ``_IMAGE_TOKEN_ID`` as the placeholder token ID.
    """
    text_embeds = self.language_model.get_input_embeddings(input_ids)

    # Nothing to merge: hand back the plain text embeddings.
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return text_embeds

    return merge_multimodal_embeddings(input_ids, text_embeds,
                                       multimodal_embeddings, _IMAGE_TOKEN_ID)
def forward(
self,
input_ids: torch.Tensor,
......@@ -689,8 +676,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == _IMAGE_TOKEN_ID,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,
......
......@@ -44,7 +44,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
logger = init_logger(__name__)
......@@ -1002,20 +1002,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self.model.get_input_embeddings(image_tokens)
return vision_embeddings
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in any multimodal embeddings.

    The text embeddings come from ``self.model``. When
    ``multimodal_embeddings`` is ``None`` or empty they are returned as-is;
    otherwise ``merge_multimodal_embeddings`` is called with
    ``self.model.vocabulary_mapping.image_token_id`` as the placeholder
    token ID.
    """
    text_embeds = self.model.get_input_embeddings(input_ids)

    # Nothing to merge: hand back the plain text embeddings.
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return text_embeds

    placeholder_id = self.model.vocabulary_mapping.image_token_id
    return merge_multimodal_embeddings(input_ids, text_embeds,
                                       multimodal_embeddings, placeholder_id)
def forward(
self,
input_ids: torch.Tensor,
......@@ -1032,8 +1018,12 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
image_token_id = self.model.vocabulary_mapping.image_token_id
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == image_token_id,
)
input_ids = None
hidden_states = self.model(input_ids,
......
......@@ -433,6 +433,9 @@ class ChatGLMBaseModel(nn.Module):
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Forward the lookup to ``self.transformer.get_input_embeddings``."""
    backbone = self.transformer
    return backbone.get_input_embeddings(input_ids)
def compute_logits(
self,
hidden_states: torch.Tensor,
......
......@@ -37,8 +37,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
class Cohere2VisionImagePixelInputs(TensorSchema):
......@@ -430,23 +429,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return self._process_image_input(image_input, **kwargs)
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in any multimodal embeddings.

    The text embeddings come from ``self.language_model``. When
    ``multimodal_embeddings`` is ``None`` or empty they are returned as-is;
    otherwise ``merge_multimodal_embeddings`` is called with
    ``self.config.image_token_id`` as ``placeholder_token_id``.
    """
    text_embeds = self.language_model.get_input_embeddings(input_ids)

    # Nothing to merge: hand back the plain text embeddings.
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return text_embeds

    placeholder_id = self.config.image_token_id
    return merge_multimodal_embeddings(
        input_ids=input_ids,
        inputs_embeds=text_embeds,
        multimodal_embeddings=multimodal_embeddings,
        placeholder_token_id=placeholder_id,
    )
def forward(
self,
input_ids: torch.Tensor,
......@@ -462,8 +444,11 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None
hidden_states = self.language_model.model(
......
......@@ -66,6 +66,9 @@ class DeepseekV2Model(nn.Module):
self.norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of calling ``self.embed_tokens`` on ``input_ids``."""
    embed = self.embed_tokens
    return embed(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......@@ -205,6 +208,9 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Forward the lookup to ``self.model.get_input_embeddings``."""
    inner = self.model
    return inner.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -101,6 +101,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of calling ``self.embed_tokens`` on ``input_ids``."""
    embed = self.embed_tokens
    return embed(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......@@ -142,6 +145,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
prefix=maybe_prefix(
prefix, "model"))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Forward the lookup to ``self.model.get_input_embeddings``."""
    inner = self.model
    return inner.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -41,8 +41,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
# The image token id may be various
_IMAGE_TOKEN = "<image>"
......@@ -346,7 +345,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
model_config = vllm_config.model_config
tokenizer = cached_tokenizer_from_config(model_config)
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN]
self.vision = self._init_vision_module(self.vision_config,
quant_config,
......@@ -605,19 +604,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in any multimodal embeddings.

    The text embeddings come from ``self.language_model``. When
    ``multimodal_embeddings`` is ``None`` or empty they are returned as-is;
    otherwise ``merge_multimodal_embeddings`` is called with
    ``self.image_token_id`` as the placeholder token ID.
    """
    text_embeds = self.language_model.get_input_embeddings(input_ids)

    # Nothing to merge: hand back the plain text embeddings.
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return text_embeds

    placeholder_id = self.image_token_id
    return merge_multimodal_embeddings(input_ids, text_embeds,
                                       multimodal_embeddings, placeholder_id)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
......@@ -632,8 +618,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.image_token_id,
)
input_ids = None
hidden_states = self.language_model(input_ids,
......
......@@ -34,8 +34,7 @@ from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder,
Qwen2VLProcessingInfo)
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
merge_multimodal_embeddings)
maybe_prefix)
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict
......@@ -796,33 +795,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in any multimodal embeddings.

    The text embeddings come from ``self.language_model``. When
    ``multimodal_embeddings`` is ``None`` or empty they are returned
    unchanged; otherwise ``merge_multimodal_embeddings`` is called with
    ``self.config.image_token_id`` as the placeholder token ID.

    An empty sequence is treated like ``None`` because there is nothing to
    merge in that case. This matches the guard used by the other
    ``get_input_embeddings`` implementations that call
    ``merge_multimodal_embeddings``.
    """
    inputs_embeds = self.language_model.get_input_embeddings(input_ids)

    # Skip the merge when no multimodal items were provided at all.
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return inputs_embeds

    return merge_multimodal_embeddings(
        input_ids,
        inputs_embeds,
        multimodal_embeddings,
        self.config.image_token_id,
    )
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
......@@ -830,15 +813,12 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None and kwargs.get("pixel_values") is not None:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
inputs_embeds = None
else:
assert input_ids is not None
inputs_embeds = self.get_multimodal_embeddings(
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(
input_ids,
image_input=image_input,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None
......
......@@ -60,8 +60,7 @@ from vllm.sequence import IntermediateTensors
from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix,
merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
from .vision import get_vit_attn_backend
logger = init_logger(__name__)
......@@ -1467,18 +1466,24 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is None:
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds,
multimodal_embeddings,
[self.config.im_patch_id])
return inputs_embeds
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,
......
......@@ -116,6 +116,9 @@ class ErnieMultiTokenPredictor(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of calling ``self.embed_tokens`` on ``input_ids``."""
    embed = self.embed_tokens
    return embed(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......@@ -160,6 +163,9 @@ class ErnieMTP(nn.Module, SupportsPP):
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Forward the lookup to ``self.model.get_input_embeddings``."""
    inner = self.model
    return inner.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -42,8 +42,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
......@@ -342,22 +341,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self._process_image_input(image_input)
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in any multimodal embeddings.

    The text embeddings come from ``self.language_model``. When
    ``multimodal_embeddings`` is ``None`` or empty they are returned as-is;
    otherwise ``merge_multimodal_embeddings`` is called with the
    module-level ``_IMAGE_TOKEN_ID`` as the placeholder token ID.
    """
    text_embeds = self.language_model.get_input_embeddings(input_ids)

    # Nothing to merge: hand back the plain text embeddings.
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return text_embeds

    return merge_multimodal_embeddings(
        input_ids,
        text_embeds,
        multimodal_embeddings,
        _IMAGE_TOKEN_ID,
    )
def forward(
self,
input_ids: torch.Tensor,
......@@ -373,8 +356,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == _IMAGE_TOKEN_ID,
)
input_ids = None
hidden_states = self.language_model(
......
......@@ -37,8 +37,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
logger = init_logger(__name__)
......@@ -588,22 +587,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return self._process_image_input(image_input)
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in any multimodal embeddings.

    The text embeddings come from ``self.language_model``. When
    ``multimodal_embeddings`` is ``None`` or empty they are returned as-is;
    otherwise ``merge_multimodal_embeddings`` is called with
    ``self.config.image_token_index`` as the placeholder token ID.
    """
    text_embeds = self.language_model.get_input_embeddings(input_ids)

    # Nothing to merge: hand back the plain text embeddings.
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return text_embeds

    placeholder_id = self.config.image_token_index
    return merge_multimodal_embeddings(
        input_ids,
        text_embeds,
        multimodal_embeddings,
        placeholder_id,
    )
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
......@@ -618,8 +601,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
if (vision_embeddings is not None) and len(vision_embeddings) != 0:
kwargs = self.prepare_attn_masks(
input_ids,
......
......@@ -632,8 +632,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
# them here, as the model forward has only access to the input_embeds.
if input_ids is not None:
......@@ -645,15 +647,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_(
per_layer_inputs)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
# NOTE: this order of processing mm items is important
[self.config.image_token_id, self.config.audio_token_id])
return inputs_embeds
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(self,
input_ids: torch.Tensor,
......
......@@ -1552,23 +1552,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in any multimodal embeddings.

    The text embeddings come from ``self.language_model``. They are
    returned as-is when ``multimodal_embeddings`` is ``None``, is empty, or
    contains any item whose ``numel()`` is not positive. Otherwise
    ``merge_multimodal_embeddings`` is called with the image and video
    token IDs from ``self.config`` as placeholders.
    """
    text_embeds = self.language_model.get_input_embeddings(input_ids)

    # Bail out unless there is at least one item and every item is non-empty.
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return text_embeds
    if any(embed.numel() <= 0 for embed in multimodal_embeddings):
        return text_embeds

    placeholder_ids = [self.config.image_token_id, self.config.video_token_id]
    return merge_multimodal_embeddings(
        input_ids,
        text_embeds,
        multimodal_embeddings,
        placeholder_ids,
    )
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment