Unverified commit 27d7638b, authored by Cyrus Leung and committed by GitHub
Browse files

[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent 17617398
...@@ -66,35 +66,12 @@ Further update the model as follows: ...@@ -66,35 +66,12 @@ Further update the model as follows:
!!! important !!! important
The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g., image) of the request. The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g., image) of the request.
- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings. !!! note
By default, vLLM merges the multimodal embeddings into the text embeddings at the locations defined in
??? code [PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing.
This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
```python
from .utils import merge_multimodal_embeddings
class YourModelForImage2Seq(nn.Module):
...
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
# `get_input_embeddings` should already be implemented for the language
# model as one of the requirements of basic vLLM model implementation.
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.image_token_index)
return inputs_embeds You may override this method if additional logic is required for your model when merging embeddings.
```
- Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model. - Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model.
......
...@@ -509,9 +509,14 @@ class ModelConfig: ...@@ -509,9 +509,14 @@ class ModelConfig:
else: # task == "auto" else: # task == "auto"
pass pass
else: else:
debug_info = {
"architectures": architectures,
"is_generative_model": is_generative_model,
"is_pooling_model": is_pooling_model,
}
raise AssertionError("The model should be a generative or " raise AssertionError("The model should be a generative or "
"pooling model when task is set to " "pooling model when task is set to "
f"{self.task!r}.") f"{self.task!r}. Found: {debug_info}")
self.runner = runner self.runner = runner
self.convert = convert self.convert = convert
......
...@@ -38,8 +38,7 @@ from .idefics2_vision_model import ( ...@@ -38,8 +38,7 @@ from .idefics2_vision_model import (
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter, maybe_prefix, is_pp_missing_parameter, maybe_prefix)
merge_multimodal_embeddings)
class AriaImagePixelInputs(TensorSchema): class AriaImagePixelInputs(TensorSchema):
...@@ -605,19 +604,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -605,19 +604,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
multimodal_embeddings = self._process_image_input(image_input) multimodal_embeddings = self._process_image_input(image_input)
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -628,10 +614,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -628,10 +614,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is None: if inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
# always pass the input via `inputs_embeds` inputs_embeds = self.get_input_embeddings(
# to make sure the computation graph is consistent input_ids,
inputs_embeds = self.get_input_embeddings(input_ids, multimodal_embeddings,
multimodal_embeddings) is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None input_ids = None
hidden_states = self.language_model( hidden_states = self.language_model(
......
...@@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix)
merge_multimodal_embeddings)
class AyaVisionImagePixelInputs(TensorSchema): class AyaVisionImagePixelInputs(TensorSchema):
...@@ -417,23 +416,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -417,23 +416,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return self._process_image_input(image_input, **kwargs) return self._process_image_input(image_input, **kwargs)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.image_token_index,
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -449,8 +431,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -449,8 +431,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None input_ids = None
hidden_states = self.language_model.model( hidden_states = self.language_model.model(
......
...@@ -348,6 +348,9 @@ class BertModel(nn.Module, SupportsQuant): ...@@ -348,6 +348,9 @@ class BertModel(nn.Module, SupportsQuant):
self.encoder = BertEncoder(vllm_config=vllm_config, self.encoder = BertEncoder(vllm_config=vllm_config,
prefix=f"{prefix}.encoder") prefix=f"{prefix}.encoder")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -457,6 +460,9 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): ...@@ -457,6 +460,9 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.pooler = self._build_pooler(pooler_config) self.pooler = self._build_pooler(pooler_config)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -588,6 +594,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, ...@@ -588,6 +594,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
), ),
}) })
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.bert.get_input_embeddings(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights) loaded_params = loader.load_weights(weights)
...@@ -637,6 +646,9 @@ class BertForTokenClassification(nn.Module): ...@@ -637,6 +646,9 @@ class BertForTokenClassification(nn.Module):
Pooler.for_encode(pooler_config), Pooler.for_encode(pooler_config),
}) })
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.bert.get_input_embeddings(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights) loaded_params = loader.load_weights(weights)
......
...@@ -426,6 +426,9 @@ class BertWithRope(nn.Module, SupportsQuant): ...@@ -426,6 +426,9 @@ class BertWithRope(nn.Module, SupportsQuant):
prefix=f"{prefix}.encoder") prefix=f"{prefix}.encoder")
self.pooler = BertPooler(self.config) if add_pooling_layer else None self.pooler = BertPooler(self.config) if add_pooling_layer else None
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -673,6 +676,9 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -673,6 +676,9 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
loaded_params = loader.load_weights(weights) loaded_params = loader.load_weights(weights)
return loaded_params return loaded_params
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.new.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
......
...@@ -27,7 +27,7 @@ from .blip import BlipVisionModel ...@@ -27,7 +27,7 @@ from .blip import BlipVisionModel
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant) SupportsQuant)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix)
# We use this internally as placeholders since there is no image token # We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo # defined on the HuggingFace repo
...@@ -631,19 +631,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -631,19 +631,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
_IMAGE_TOKEN_ID)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -689,8 +676,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -689,8 +676,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=input_ids == _IMAGE_TOKEN_ID,
)
input_ids = None input_ids = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
......
...@@ -44,7 +44,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, ...@@ -44,7 +44,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant) SupportsQuant)
from .utils import (flatten_bn, is_pp_missing_parameter, from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1002,20 +1002,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1002,20 +1002,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self.model.get_input_embeddings(image_tokens) vision_embeddings = self.model.get_input_embeddings(image_tokens)
return vision_embeddings return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.model.vocabulary_mapping.image_token_id)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -1032,8 +1018,12 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1032,8 +1018,12 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, image_token_id = self.model.vocabulary_mapping.image_token_id
vision_embeddings) inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == image_token_id,
)
input_ids = None input_ids = None
hidden_states = self.model(input_ids, hidden_states = self.model(input_ids,
......
...@@ -433,6 +433,9 @@ class ChatGLMBaseModel(nn.Module): ...@@ -433,6 +433,9 @@ class ChatGLMBaseModel(nn.Module):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors) self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -37,8 +37,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -37,8 +37,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix)
merge_multimodal_embeddings)
class Cohere2VisionImagePixelInputs(TensorSchema): class Cohere2VisionImagePixelInputs(TensorSchema):
...@@ -430,23 +429,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -430,23 +429,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return self._process_image_input(image_input, **kwargs) return self._process_image_input(image_input, **kwargs)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.image_token_id,
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -462,8 +444,11 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -462,8 +444,11 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None input_ids = None
hidden_states = self.language_model.model( hidden_states = self.language_model.model(
......
...@@ -66,6 +66,9 @@ class DeepseekV2Model(nn.Module): ...@@ -66,6 +66,9 @@ class DeepseekV2Model(nn.Module):
self.norm = RMSNorm(self.config.hidden_size, self.norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps) eps=self.config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -205,6 +208,9 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): ...@@ -205,6 +208,9 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
self.logits_processor = LogitsProcessor(self.config.vocab_size, self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale) scale=logit_scale)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -101,6 +101,9 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -101,6 +101,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -142,6 +145,9 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -142,6 +145,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
prefix=maybe_prefix( prefix=maybe_prefix(
prefix, "model")) prefix, "model"))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -41,8 +41,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -41,8 +41,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix)
merge_multimodal_embeddings)
# The image token id may vary # The image token id may vary
_IMAGE_TOKEN = "<image>" _IMAGE_TOKEN = "<image>"
...@@ -346,7 +345,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -346,7 +345,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
model_config = vllm_config.model_config model_config = vllm_config.model_config
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN] self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN]
self.vision = self._init_vision_module(self.vision_config, self.vision = self._init_vision_module(self.vision_config,
quant_config, quant_config,
...@@ -605,19 +604,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -605,19 +604,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.image_token_id)
return inputs_embeds
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -632,8 +618,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -632,8 +618,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility # condition is for v0 compatibility
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=input_ids == self.image_token_id,
)
input_ids = None input_ids = None
hidden_states = self.language_model(input_ids, hidden_states = self.language_model(input_ids,
......
...@@ -34,8 +34,7 @@ from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder, ...@@ -34,8 +34,7 @@ from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder,
Qwen2VLProcessingInfo) Qwen2VLProcessingInfo)
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix)
merge_multimodal_embeddings)
from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict from vllm.multimodal.inputs import MultiModalDataDict
...@@ -796,33 +795,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -796,33 +795,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_id,
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor], input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
...@@ -830,15 +813,12 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -830,15 +813,12 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
elif inputs_embeds is None and kwargs.get("pixel_values") is not None: elif inputs_embeds is None:
image_input = self._parse_and_validate_image_input(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
if image_input is None: inputs_embeds = self.get_input_embeddings(
inputs_embeds = None
else:
assert input_ids is not None
inputs_embeds = self.get_multimodal_embeddings(
input_ids, input_ids,
image_input=image_input, vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
) )
input_ids = None input_ids = None
......
...@@ -60,8 +60,7 @@ from vllm.sequence import IntermediateTensors ...@@ -60,8 +60,7 @@ from vllm.sequence import IntermediateTensors
from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix, from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
merge_multimodal_embeddings)
from .vision import get_vit_attn_backend from .vision import get_vit_attn_backend
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1467,18 +1466,24 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1467,18 +1466,24 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
inputs_embeds = self.language_model.get_input_embeddings(input_ids) # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
if multimodal_embeddings is None: return super().get_input_embeddings(input_ids)
return inputs_embeds
self._set_visual_token_mask(input_ids) return super().get_input_embeddings(
inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds, input_ids,
multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
[self.config.im_patch_id]) is_multimodal=is_multimodal,
return inputs_embeds handle_oov_mm_token=handle_oov_mm_token,
)
def forward( def forward(
self, self,
......
...@@ -116,6 +116,9 @@ class ErnieMultiTokenPredictor(nn.Module): ...@@ -116,6 +116,9 @@ class ErnieMultiTokenPredictor(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -160,6 +163,9 @@ class ErnieMTP(nn.Module, SupportsPP): ...@@ -160,6 +163,9 @@ class ErnieMTP(nn.Module, SupportsPP):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -42,8 +42,7 @@ from vllm.sequence import IntermediateTensors ...@@ -42,8 +42,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
merge_multimodal_embeddings)
# Cannot find the following 2 numbers from hf config. # Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011 _IMAGE_TOKEN_ID = 71011
...@@ -342,22 +341,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -342,22 +341,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self._process_image_input(image_input) return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
_IMAGE_TOKEN_ID,
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -373,8 +356,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -373,8 +356,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=input_ids == _IMAGE_TOKEN_ID,
)
input_ids = None input_ids = None
hidden_states = self.language_model( hidden_states = self.language_model(
......
...@@ -37,8 +37,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, ...@@ -37,8 +37,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix)
merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -588,22 +587,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -588,22 +587,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return self._process_image_input(image_input) return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -618,8 +601,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -618,8 +601,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
if (vision_embeddings is not None) and len(vision_embeddings) != 0: if (vision_embeddings is not None) and len(vision_embeddings) != 0:
kwargs = self.prepare_attn_masks( kwargs = self.prepare_attn_masks(
input_ids, input_ids,
......
...@@ -632,8 +632,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -632,8 +632,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
# them here, as the model forward has only access to the input_embeds. # them here, as the model forward has only access to the input_embeds.
if input_ids is not None: if input_ids is not None:
...@@ -645,15 +647,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -645,15 +647,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_( self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_(
per_layer_inputs) per_layer_inputs)
if multimodal_embeddings is not None \ # This is to satisfy the type checker for each overload
and len(multimodal_embeddings) != 0: if multimodal_embeddings is None or is_multimodal is None:
inputs_embeds = merge_multimodal_embeddings( return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids, input_ids,
inputs_embeds, multimodal_embeddings=multimodal_embeddings,
multimodal_embeddings, is_multimodal=is_multimodal,
# NOTE: this order of processing mm items is important handle_oov_mm_token=handle_oov_mm_token,
[self.config.image_token_id, self.config.audio_token_id]) )
return inputs_embeds
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -1552,23 +1552,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1552,23 +1552,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings += video_embeddings multimodal_embeddings += video_embeddings
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0
and all(embed.numel() > 0 for embed in multimodal_embeddings)):
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id],
)
return inputs_embeds
def get_input_embeddings_v0( def get_input_embeddings_v0(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment