Unverified commit 27d7638b, authored by Cyrus Leung, committed by GitHub
Browse files

[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent 17617398
...@@ -1365,19 +1365,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1365,19 +1365,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings += video_embeddings multimodal_embeddings += video_embeddings
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
return inputs_embeds
def get_input_embeddings_v0( def get_input_embeddings_v0(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -49,8 +49,7 @@ from vllm.sequence import IntermediateTensors ...@@ -49,8 +49,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
maybe_prefix, merge_multimodal_embeddings)
# # === Audio Inputs === # # # === Audio Inputs === #
...@@ -438,19 +437,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -438,19 +437,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
masked_audio_features = self._process_audio_input(audio_input) masked_audio_features = self._process_audio_input(audio_input)
return masked_audio_features return masked_audio_features
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.audio_token_index)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -467,8 +453,11 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -467,8 +453,11 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
multimodal_embeddings) input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.audio_token_index,
)
input_ids = None input_ids = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
......
...@@ -1459,19 +1459,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1459,19 +1459,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
return inputs_embeds
def get_input_embeddings_v0( def get_input_embeddings_v0(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -79,7 +79,8 @@ from .qwen2_5_vl import (Qwen2_5_VisionAttention, ...@@ -79,7 +79,8 @@ from .qwen2_5_vl import (Qwen2_5_VisionAttention,
from .qwen2_vl import Qwen2VLProcessingInfo from .qwen2_vl import Qwen2VLProcessingInfo
from .qwen3 import Qwen3ForCausalLM, Qwen3Model from .qwen3 import Qwen3ForCausalLM, Qwen3Model
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
maybe_prefix, merge_multimodal_embeddings) _merge_multimodal_embeddings, maybe_prefix,
merge_multimodal_embeddings)
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1324,17 +1325,22 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1324,17 +1325,22 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return multimodal_embeddings return multimodal_embeddings
def _compute_deepstack_embeds( def _compute_deepstack_embeds(
self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor, self,
multimodal_embeddings: MultiModalEmbeddings) -> torch.Tensor: inputs_embeds: torch.Tensor,
visual_lens = [ multimodal_embeddings: MultiModalEmbeddings,
x.shape[0] if isinstance(x, torch.Tensor) else len(x) is_multimodal: torch.Tensor,
for x in multimodal_embeddings ) -> tuple[torch.Tensor, MultiModalEmbeddings]:
] visual_lens = [len(x) for x in multimodal_embeddings]
multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0) multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split( # noqa:E501 (
multimodal_embeddings_cat, [self.visual_dim, self.multiscale_dim], multimodal_embeddings_main,
dim=-1) multimodal_embeddings_multiscale,
) = torch.split(
multimodal_embeddings_cat,
[self.visual_dim, self.multiscale_dim],
dim=-1,
)
multimodal_embeddings = torch.split(multimodal_embeddings_main, multimodal_embeddings = torch.split(multimodal_embeddings_main,
visual_lens, visual_lens,
...@@ -1346,39 +1352,62 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1346,39 +1352,62 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds.size(0), inputs_embeds.size(0),
self.deepstack_num_level * inputs_embeds.size(1)) self.deepstack_num_level * inputs_embeds.size(1))
deepstack_input_embeds = merge_multimodal_embeddings( deepstack_input_embeds = _merge_multimodal_embeddings(
input_ids, inputs_embeds=deepstack_input_embeds,
deepstack_input_embeds, multimodal_embeddings=multimodal_embeddings_multiscale,
multimodal_embeddings_multiscale, is_multimodal=is_multimodal,
placeholder_token_id=[
self.config.image_token_id, self.config.video_token_id
],
) )
deepstack_input_embeds = deepstack_input_embeds.view( deepstack_input_embeds = deepstack_input_embeds.view(
inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim) inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim)
deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2) deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)
return deepstack_input_embeds, multimodal_embeddings return deepstack_input_embeds, multimodal_embeddings
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
deepstack_input_embeds = None inputs_embeds = self._get_text_embeddings(
inputs_embeds = self.language_model.get_input_embeddings(input_ids) input_ids,
if multimodal_embeddings is not None: self.language_model.get_input_embeddings,
if self.use_deepstack: is_multimodal=is_multimodal,
deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds( # noqa:E501 handle_oov_mm_token=handle_oov_mm_token,
input_ids, inputs_embeds, multimodal_embeddings) )
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
[self.config.image_token_id, self.config.video_token_id]) return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")
if self.use_deepstack: if self.use_deepstack:
if deepstack_input_embeds is None: (
deepstack_input_embeds = torch.zeros_like( deepstack_input_embeds,
inputs_embeds).unsqueeze(0).repeat( multimodal_embeddings,
self.deepstack_num_level, 1, 1).contiguous() ) = self._compute_deepstack_embeds(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
else:
deepstack_input_embeds = None
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
if deepstack_input_embeds is not None:
deepstack_input_embeds = torch.zeros_like(inputs_embeds).unsqueeze(
0).repeat(self.deepstack_num_level, 1, 1).contiguous()
self._set_deepstack_input_embeds(deepstack_input_embeds) self._set_deepstack_input_embeds(deepstack_input_embeds)
return inputs_embeds return inputs_embeds
......
...@@ -45,7 +45,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -45,7 +45,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
from .qwen import QWenBaseModel, QWenModel from .qwen import QWenBaseModel, QWenModel
from .utils import flatten_bn, merge_multimodal_embeddings from .utils import flatten_bn
class QwenImagePixelInputs(TensorSchema): class QwenImagePixelInputs(TensorSchema):
...@@ -756,21 +756,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, ...@@ -756,21 +756,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.transformer.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.transformer.visual.image_pad_id)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -786,8 +771,12 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, ...@@ -786,8 +771,12 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=input_ids ==
self.transformer.visual.image_pad_id,
)
input_ids = None input_ids = None
hidden_states = self.transformer(input_ids, positions, hidden_states = self.transformer(input_ids, positions,
......
...@@ -218,6 +218,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -218,6 +218,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper) return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.roberta.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
......
...@@ -38,7 +38,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -38,7 +38,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix)
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
...@@ -842,19 +842,24 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -842,19 +842,24 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None and len(
if multimodal_embeddings is not None \ multimodal_embeddings) > 0:
and len(multimodal_embeddings) != 0:
assert self.img_context_token_id is not None
self._set_visual_token_mask(input_ids) self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, # This is to satisfy the type checker for each overload
inputs_embeds, if multimodal_embeddings is None or is_multimodal is None:
multimodal_embeddings, return super().get_input_embeddings(input_ids)
self.img_context_token_id,
) return super().get_input_embeddings(
return inputs_embeds input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward( def forward(
self, self,
...@@ -873,8 +878,11 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -873,8 +878,11 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=input_ids == self.img_context_token_id,
)
input_ids = None input_ids = None
forward_kwargs = { forward_kwargs = {
......
...@@ -483,6 +483,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -483,6 +483,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -395,6 +395,9 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): ...@@ -395,6 +395,9 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
......
...@@ -25,7 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -25,7 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors) MultiModalKwargsItems)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
...@@ -37,8 +37,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer ...@@ -37,8 +37,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix)
merge_multimodal_embeddings)
from .vision import run_dp_sharded_vision_model from .vision import run_dp_sharded_vision_model
...@@ -996,10 +995,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -996,10 +995,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
1 else cur_feature[0]) 1 else cur_feature[0])
return merged_image_features return merged_image_features
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
...@@ -1007,24 +1009,21 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1007,24 +1009,21 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
if multimodal_embeddings is None: # This is to satisfy the type checker for each overload
inputs_embeds = self.language_model.model.get_input_embeddings( if multimodal_embeddings is None or is_multimodal is None:
input_ids) return super().get_input_embeddings(input_ids)
else:
is_text = input_ids != self.config.image_token_id return super().get_input_embeddings(
text_ids = input_ids[is_text] input_ids,
text_embeds = self.language_model.model.get_input_embeddings( multimodal_embeddings=multimodal_embeddings,
text_ids) is_multimodal=is_multimodal,
inputs_embeds = torch.empty(input_ids.shape[0], handle_oov_mm_token=handle_oov_mm_token,
text_embeds.shape[-1], )
dtype=text_embeds.dtype,
device=text_embeds.device)
inputs_embeds[is_text] = text_embeds
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_id)
return inputs_embeds
def forward( def forward(
self, self,
...@@ -1038,10 +1037,11 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1038,10 +1037,11 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = None inputs_embeds = None
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
# always pass the input via `inputs_embeds` inputs_embeds = self.get_input_embeddings(
# to make sure the computation graph is consistent input_ids,
inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings,
vision_embeddings) is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None input_ids = None
hidden_states = self.language_model(input_ids, hidden_states = self.language_model(input_ids,
......
...@@ -40,7 +40,7 @@ from .clip import CLIPVisionModel ...@@ -40,7 +40,7 @@ from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix)
from .vision import VisionEncoderInfo, get_vision_encoder_info from .vision import VisionEncoderInfo, get_vision_encoder_info
...@@ -589,22 +589,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -589,22 +589,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
return [] return []
return self._process_image_input(image_input) return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -617,8 +601,11 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -617,8 +601,11 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = None inputs_embeds = None
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
vision_embeddings) input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None input_ids = None
hidden_states = self.language_model.model( hidden_states = self.language_model.model(
input_ids=input_ids, input_ids=input_ids,
......
...@@ -233,6 +233,9 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): ...@@ -233,6 +233,9 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# We do not really use any input tokens and therefore no embeddings # We do not really use any input tokens and therefore no embeddings
# to be calculated. However, due to the mandatory token ids in # to be calculated. However, due to the mandatory token ids in
......
...@@ -52,8 +52,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder ...@@ -52,8 +52,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsQuant) SupportsMultiModal, SupportsPP, SupportsQuant)
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
flatten_bn, make_empty_intermediate_tensors_factory, flatten_bn, make_empty_intermediate_tensors_factory,
maybe_prefix) maybe_prefix)
...@@ -797,6 +797,9 @@ class TransformersForCausalLM(TransformersBase): ...@@ -797,6 +797,9 @@ class TransformersForCausalLM(TransformersBase):
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings()(input_ids)
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -873,13 +876,19 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): ...@@ -873,13 +876,19 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
multimodal_embeds = self.get_multimodal_embeddings(**kwargs) multimodal_embeds = self.get_multimodal_embeddings(**kwargs)
if multimodal_embeds is not None: if multimodal_embeds is not None:
inputs_embeds = self.get_input_embeddings( inputs_embeds = self.get_input_embeddings(
input_ids, multimodal_embeds) input_ids,
multimodal_embeds,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None input_ids = None
model_output = super().forward(input_ids, positions, model_output = super().forward(input_ids, positions,
intermediate_tensors, inputs_embeds) intermediate_tensors, inputs_embeds)
return model_output return model_output
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(self, **kwargs): def get_multimodal_embeddings(self, **kwargs):
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
pixel_values = pixel_values if pixel_values is not None else kwargs.pop( pixel_values = pixel_values if pixel_values is not None else kwargs.pop(
...@@ -934,15 +943,42 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): ...@@ -934,15 +943,42 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings=None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings()(input_ids) """
if (multimodal_embeddings is not None Apply token embeddings to `input_ids`.
and len(multimodal_embeddings) != 0):
mask = (input_ids == self.config.image_token_id) If `multimodal_embeddings` is passed, scatter them into
mask = mask.unsqueeze(-1).expand_as(inputs_embeds) `input_ids` according to the mask `is_multimodal`.
multimodal_embeddings = torch.cat(multimodal_embeddings)
In case the multi-modal token IDs exceed the vocabulary size of
inputs_embeds = inputs_embeds.masked_scatter( the language model, you can set `handle_oov_mm_token=False`
mask, multimodal_embeddings) to avoid calling the language model's `get_input_embeddings` method
return inputs_embeds on those tokens.
"""
from .utils import _merge_multimodal_embeddings
inputs_embeds = self._get_text_embeddings(
input_ids,
self.model.get_input_embeddings(),
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
...@@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix)
merge_multimodal_embeddings)
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>" _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
_MAX_ENCODER_BATCH_SIZE = 16 _MAX_ENCODER_BATCH_SIZE = 16
...@@ -555,19 +554,21 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -555,19 +554,21 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
# The audio token index is not included in the embedding table # This is to satisfy the type checker for each overload
# We need to remove it before embedding lookup if multimodal_embeddings is None or is_multimodal is None:
safe_input_ids = input_ids.clone() return super().get_input_embeddings(input_ids)
safe_input_ids[safe_input_ids == self.config.audio_token_index] = 0
inputs_embeds = self.language_model.get_input_embeddings( return super().get_input_embeddings(
safe_input_ids) input_ids,
if multimodal_embeddings is not None and len( multimodal_embeddings=multimodal_embeddings,
multimodal_embeddings) > 0: is_multimodal=is_multimodal,
inputs_embeds = merge_multimodal_embeddings( handle_oov_mm_token=handle_oov_mm_token,
input_ids, inputs_embeds, multimodal_embeddings, )
self.config.audio_token_index)
return inputs_embeds
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -601,8 +602,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -601,8 +602,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
elif inputs_embeds is None: elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
multimodal_embeddings) input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.audio_token_index,
)
input_ids = None input_ids = None
language_model = self.language_model language_model = self.language_model
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import itertools import itertools
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Literal, Optional, Protocol, Union, overload from typing import Any, Literal, Optional, Protocol, Union, overload
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -391,8 +391,8 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: ...@@ -391,8 +391,8 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
def _merge_multimodal_embeddings( def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor,
multimodal_embeddings: NestedTensors, multimodal_embeddings: NestedTensors,
is_multimodal: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
...@@ -402,61 +402,35 @@ def _merge_multimodal_embeddings( ...@@ -402,61 +402,35 @@ def _merge_multimodal_embeddings(
Note: Note:
This updates ``inputs_embeds`` in place. This updates ``inputs_embeds`` in place.
""" """
flattened = _flatten_embeddings(multimodal_embeddings) if len(multimodal_embeddings) == 0:
return inputs_embeds
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
input_dtype = inputs_embeds.dtype
try: try:
# This is equivalent to: inputs_embeds[is_multimodal] = flattened. # For debugging
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
flattened.to(dtype=inputs_embeds.dtype)) mm_embeds_flat.to(dtype=input_dtype))
except RuntimeError as e: except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item() num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
if flattened.shape[0] != num_expected_tokens: if num_actual_tokens != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings) expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError( raise ValueError(
f"Attempted to assign {expr} = {flattened.shape[0]} " f"Attempted to assign {expr} = {num_actual_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders" f"multimodal tokens to {num_expected_tokens} placeholders"
) from e ) from e
else:
raise ValueError("Error during masked scatter operation") from e
return inputs_embeds
def embed_multimodal(
input_ids: torch.Tensor,
multimodal_token_id: int,
get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
multimodal_embeds: NestedTensors,
) -> torch.Tensor:
"""
Embed token IDs and multimodal inputs and combine their embeddings.
``multimodal_token_id`` is used to determine whether a token ID should
be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.
Compared to ``merge_multimodal_embeddings`, this avoids running raise ValueError("Error during masked scatter operation") from e
``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
which causes issues when the placeholder token ID exceeds the
vocabulary size of the language model.
"""
is_multimodal = input_ids == multimodal_token_id
is_text = ~is_multimodal
text_embeds = get_text_embeds(input_ids[is_text])
merged_embeds = torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
)
merged_embeds[is_text] = text_embeds return inputs_embeds
return _merge_multimodal_embeddings(
merged_embeds,
is_multimodal,
multimodal_embeds,
)
def merge_multimodal_embeddings( def merge_multimodal_embeddings(
...@@ -491,23 +465,29 @@ def merge_multimodal_embeddings( ...@@ -491,23 +465,29 @@ def merge_multimodal_embeddings(
This updates ``inputs_embeds`` in place. This updates ``inputs_embeds`` in place.
""" """
if isinstance(placeholder_token_id, list): if isinstance(placeholder_token_id, list):
placeholder_token_id = torch.tensor( is_multimodal = isin_list(input_ids, placeholder_token_id)
placeholder_token_id, else:
pin_memory=is_pin_memory_available()).to(device=input_ids.device, is_multimodal = (input_ids == placeholder_token_id)
non_blocking=True)
return _merge_multimodal_embeddings(
inputs_embeds,
torch.isin(input_ids, placeholder_token_id),
multimodal_embeddings,
)
return _merge_multimodal_embeddings( return _merge_multimodal_embeddings(
inputs_embeds, inputs_embeds,
(input_ids == placeholder_token_id), multimodal_embeddings=multimodal_embeddings,
multimodal_embeddings, is_multimodal=is_multimodal,
) )
def isin_list(
    elements: torch.Tensor,
    test_elements_list: list[int],
) -> torch.Tensor:
    """
    Test each entry of ``elements`` for membership in a Python list of ints.

    Args:
        elements: Tensor whose entries are checked.
        test_elements_list: Values to test membership against.

    Returns:
        The result of ``torch.isin(elements, ...)`` evaluated against the
        list after it has been moved to ``elements.device``.
    """
    # Build the lookup values on the host first; pinned memory is requested
    # only when the platform reports it as available.
    use_pinned = is_pin_memory_available()
    host_values = torch.tensor(test_elements_list, pin_memory=use_pinned)

    # NOTE(review): the pinned buffer + ``non_blocking=True`` pairing is
    # presumably there so the host-to-device copy does not stall the caller.
    device_values = host_values.to(device=elements.device, non_blocking=True)

    return torch.isin(elements, device_values)
class LayerFn(Protocol): class LayerFn(Protocol):
def __call__(self, prefix: str) -> torch.nn.Module: def __call__(self, prefix: str) -> torch.nn.Module:
......
...@@ -45,10 +45,8 @@ from vllm.sequence import IntermediateTensors ...@@ -45,10 +45,8 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer, from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config) cached_tokenizer_from_config)
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
SupportsMultiModal, SupportsTranscription) from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -376,9 +374,14 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -376,9 +374,14 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: In v1, inputs_embeds is always generated at model runner, this # NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility. # condition is for v0 compatibility.
elif inputs_embeds is None: elif inputs_embeds is None:
audio_encoder = self.tokenizer.instruct.audio_encoder
audio_tok_id = audio_encoder.audio_token
audio_embeddings = self.get_multimodal_embeddings(**kwargs) audio_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(
audio_embeddings) input_ids,
audio_embeddings,
is_multimodal=input_ids == audio_tok_id,
)
input_ids = None input_ids = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
...@@ -421,20 +424,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -421,20 +424,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return audio_embeddings return audio_embeddings
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """
    Embed ``input_ids`` and merge audio embeddings into the result.

    Args:
        input_ids: Token IDs passed to the language model's embedding
            lookup.
        multimodal_embeddings: Audio embeddings to merge at the positions
            of the audio placeholder token. ``None`` or an empty
            collection means there is nothing to merge.

    Returns:
        The language model's input embeddings, with the audio embeddings
        merged in by ``merge_multimodal_embeddings`` when any are given.
    """
    inputs_embeds = self.language_model.get_input_embeddings(input_ids)

    # Skip the merge for text-only batches. An empty collection is treated
    # like ``None`` so the merge helper is never asked to flatten zero
    # embeddings.
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return inputs_embeds

    # The placeholder token ID comes from the tokenizer's audio encoder.
    audio_encoder = self.tokenizer.instruct.audio_encoder
    audio_tok_id = audio_encoder.audio_token

    return merge_multimodal_embeddings(
        input_ids, inputs_embeds, multimodal_embeddings, audio_tok_id)
def _parse_and_validate_audio_arrays( def _parse_and_validate_audio_arrays(
self, **kwargs: object) -> Union[list[torch.Tensor], None]: self, **kwargs: object) -> Union[list[torch.Tensor], None]:
audio_arrays = kwargs.pop("audio_arrays", None) audio_arrays = kwargs.pop("audio_arrays", None)
......
...@@ -579,10 +579,7 @@ class WhisperDecoder(nn.Module): ...@@ -579,10 +579,7 @@ class WhisperDecoder(nn.Module):
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
return hidden_states return hidden_states
def get_input_embeddings( def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -916,7 +913,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ...@@ -916,7 +913,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# This method just returns the decoder sequence embeddings since # This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens. # Whisper does not have encoder text tokens.
......
...@@ -18,6 +18,7 @@ from vllm.logger import init_logger ...@@ -18,6 +18,7 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
...@@ -64,8 +65,10 @@ class EagleProposer: ...@@ -64,8 +65,10 @@ class EagleProposer:
# hidden size (e.g., Llama 3.3 70B). # hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size() self.hidden_size = self.draft_model_config.get_hidden_size()
self.is_multimodal_model = vllm_config.model_config \ # Multi-modal data support
.is_multimodal_model self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
vllm_config.model_config)
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
...@@ -175,7 +178,8 @@ class EagleProposer: ...@@ -175,7 +178,8 @@ class EagleProposer:
last_token_indices: Optional[torch.Tensor], last_token_indices: Optional[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
mm_embeds: Optional[list[torch.Tensor]] = None, mm_embed_inputs: Optional[tuple[list[torch.Tensor],
torch.Tensor]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = target_token_ids.shape[0] num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0] batch_size = next_token_ids.shape[0]
...@@ -219,18 +223,21 @@ class EagleProposer: ...@@ -219,18 +223,21 @@ class EagleProposer:
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self._set_positions(num_tokens, target_positions) self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
if self.is_multimodal_model:
input_ids = self.input_ids[:num_tokens] if self.supports_mm_inputs:
inputs_embeds = self.model.get_input_embeddings( mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
input_ids,
multimodal_embeddings=mm_embeds or None, self.inputs_embeds[:num_tokens] = self.model.get_input_embeddings(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
) )
self.inputs_embeds[:num_tokens] = inputs_embeds
inputs_embeds = self.inputs_embeds[:num_input_tokens]
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else: else:
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens] input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
...@@ -372,14 +379,15 @@ class EagleProposer: ...@@ -372,14 +379,15 @@ class EagleProposer:
self.input_ids[:batch_size] = input_ids self.input_ids[:batch_size] = input_ids
self._set_positions(batch_size, clamped_positions) self._set_positions(batch_size, clamped_positions)
self.hidden_states[:batch_size] = hidden_states self.hidden_states[:batch_size] = hidden_states
if self.is_multimodal_model: if self.supports_mm_inputs:
inputs_embeds = self.model.get_input_embeddings(input_ids) self.inputs_embeds[:batch_size] = \
self.inputs_embeds[:batch_size] = inputs_embeds self.model.get_input_embeddings(input_ids)
inputs_embeds = self.inputs_embeds[:input_batch_size]
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds[:input_batch_size]
else: else:
inputs_embeds = None
input_ids = self.input_ids[:input_batch_size] input_ids = self.input_ids[:input_batch_size]
inputs_embeds = None
# Run the model. # Run the model.
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
...@@ -849,7 +857,7 @@ class EagleProposer: ...@@ -849,7 +857,7 @@ class EagleProposer:
self.attn_layer_names = list(draft_attn_layer_names) self.attn_layer_names = list(draft_attn_layer_names)
if self.is_multimodal_model: if self.supports_mm_inputs:
# Even if the target model is multimodal, we can also use # Even if the target model is multimodal, we can also use
# text-only draft models # text-only draft models
try: try:
...@@ -861,7 +869,7 @@ class EagleProposer: ...@@ -861,7 +869,7 @@ class EagleProposer:
logger.warning( logger.warning(
"Draft model does not support multimodal inputs, " "Draft model does not support multimodal inputs, "
"falling back to text-only mode") "falling back to text-only mode")
self.is_multimodal_model = False self.supports_mm_inputs = False
if supports_multimodal(target_model): if supports_multimodal(target_model):
# handle multimodality # handle multimodality
...@@ -933,7 +941,7 @@ class EagleProposer: ...@@ -933,7 +941,7 @@ class EagleProposer:
) -> None: ) -> None:
with set_forward_context(None, self.vllm_config, with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens): num_tokens=num_tokens):
if self.is_multimodal_model: if self.supports_mm_inputs:
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens] inputs_embeds = self.inputs_embeds[:num_tokens]
else: else:
......
...@@ -368,6 +368,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -368,6 +368,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int64) dtype=torch.int64)
# Only relevant for multimodal models
if self.supports_mm_inputs:
self.is_mm_embed = self._make_buffer(self.max_num_tokens,
dtype=torch.bool)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope: if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy # NOTE: `mrope_positions` is implemented with one additional dummy
...@@ -1627,9 +1632,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1627,9 +1632,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
shift_computed_tokens: int = 0, shift_computed_tokens: int = 0,
) -> list[torch.Tensor]: ) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
mm_embeds = list[torch.Tensor]()
is_mm_embed = self.is_mm_embed.cpu
is_mm_embed[:total_num_scheduled_tokens] = False
req_start_idx = 0
should_sync_mrope_positions = False should_sync_mrope_positions = False
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids: for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = [] mm_embeds_req: list[torch.Tensor] = []
...@@ -1638,6 +1650,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1638,6 +1650,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id] req_state = self.requests[req_id]
num_computed_tokens = \ num_computed_tokens = \
req_state.num_computed_tokens + shift_computed_tokens req_state.num_computed_tokens + shift_computed_tokens
for mm_feature in req_state.mm_features: for mm_feature in req_state.mm_features:
pos_info = mm_feature.mm_position pos_info = mm_feature.mm_position
start_pos = pos_info.offset start_pos = pos_info.offset
...@@ -1670,6 +1683,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1670,6 +1683,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if (is_embed := pos_info.is_embed) is not None: if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx] is_embed = is_embed[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
= True if is_embed is None else is_embed
mm_embeds_item = gather_mm_placeholders( mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx], encoder_output[start_idx:end_idx],
is_embed=is_embed, is_embed=is_embed,
...@@ -1677,6 +1694,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1677,6 +1694,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_embeds_req.append(mm_embeds_item) mm_embeds_req.append(mm_embeds_item)
if self.is_multimodal_pruning_enabled and self.uses_mrope: if self.is_multimodal_pruning_enabled and self.uses_mrope:
assert req_state.mrope_positions is not None
should_sync_mrope_positions = True should_sync_mrope_positions = True
mm_embeds_req, new_mrope_positions, new_delta = ( mm_embeds_req, new_mrope_positions, new_delta = (
self.model.recompute_mrope_positions( self.model.recompute_mrope_positions(
...@@ -1685,18 +1703,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1685,18 +1703,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mrope_positions=req_state.mrope_positions, mrope_positions=req_state.mrope_positions,
num_computed_tokens=req_state.num_computed_tokens, num_computed_tokens=req_state.num_computed_tokens,
)) ))
assert req_state.mrope_positions is not None
req_state.mrope_positions.copy_(new_mrope_positions) req_state.mrope_positions.copy_(new_mrope_positions)
req_state.mrope_position_delta = new_delta req_state.mrope_position_delta = new_delta
mm_embeds.extend(mm_embeds_req) mm_embeds.extend(mm_embeds_req)
req_start_idx += num_scheduled_tokens
is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
if should_sync_mrope_positions: if should_sync_mrope_positions:
self._calc_mrope_positions(scheduler_output) self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu( self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
scheduler_output.total_num_scheduled_tokens)
return mm_embeds return mm_embeds, is_mm_embed
def _extract_encoder_inputs( def _extract_encoder_inputs(
self, self,
...@@ -1990,14 +2009,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1990,14 +2009,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
and not self.model_config.is_encoder_decoder): and not self.model_config.is_encoder_decoder):
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output) self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output) mm_embeds, is_mm_embed = self._gather_mm_embeddings(
scheduler_output)
# NOTE(woosuk): To unify token ids and soft tokens (vision # NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids) # embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text. # as input to the multimodal model, even when the input is text.
inputs_embeds_scheduled = self.model.get_input_embeddings( inputs_embeds_scheduled = self.model.get_input_embeddings(
input_ids=self.input_ids.gpu[:num_scheduled_tokens], self.input_ids.gpu[:num_scheduled_tokens],
multimodal_embeddings=mm_embeds or None, multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
) )
# TODO(woosuk): Avoid the copy. Optimize. # TODO(woosuk): Avoid the copy. Optimize.
...@@ -2586,10 +2607,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2586,10 +2607,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
[h[token_indices] for h in aux_hidden_states], dim=-1) [h[token_indices] for h in aux_hidden_states], dim=-1)
else: else:
target_hidden_states = hidden_states[token_indices] target_hidden_states = hidden_states[token_indices]
mm_embeds = None
if self.supports_mm_inputs: if self.supports_mm_inputs:
mm_embeds = self._gather_mm_embeddings(scheduler_output, mm_embed_inputs = self._gather_mm_embeddings(
shift_computed_tokens=1) scheduler_output,
shift_computed_tokens=1,
)
else:
mm_embed_inputs = None
draft_token_ids = self.drafter.propose( draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
...@@ -2599,8 +2624,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2599,8 +2624,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
last_token_indices=token_indices_to_sample, last_token_indices=token_indices_to_sample,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
mm_embeds=mm_embeds, mm_embed_inputs=mm_embed_inputs,
) )
return draft_token_ids return draft_token_ids
def update_config(self, overrides: dict[str, Any]) -> None: def update_config(self, overrides: dict[str, Any]) -> None:
......
...@@ -263,6 +263,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -263,6 +263,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy() self.seq_lens_np = self.seq_lens_cpu.numpy()
# Only relevant for multimodal models
if self.supports_mm_inputs:
self.is_mm_embed_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory)
# Range tensor with values [0 .. self.max_num_tokens - 1]. # Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens # Used to initialize positions / context_lens / seq_lens
# Keep in int64 to avoid overflow with long context # Keep in int64 to avoid overflow with long context
...@@ -879,13 +886,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -879,13 +886,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _gather_mm_embeddings( def _gather_mm_embeddings(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]: ) -> tuple[list[torch.Tensor], torch.Tensor]:
mm_embeds: list[torch.Tensor] = [] total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
padded_total_num_scheduled_tokens = _get_padded_token_len(
self.num_tokens_paddings, total_num_scheduled_tokens)
is_mm_embed = self.is_mm_embed_cpu
is_mm_embed[:padded_total_num_scheduled_tokens] = False
mm_embeds = list[torch.Tensor]()
req_start_idx = 0
for req_id in self.input_batch.req_ids: for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id] req_id]
req_state = self.requests[req_id] req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens num_computed_tokens = req_state.num_computed_tokens
# TODO unroll loop and assume/enforce --disable_chunked_mm_input # TODO unroll loop and assume/enforce --disable_chunked_mm_input
# NOTE (NickLucche) here we diverge from logic in other runners, as # NOTE (NickLucche) here we diverge from logic in other runners, as
# we assume to only have whole mm items to process. Hence we avoid # we assume to only have whole mm items to process. Hence we avoid
...@@ -906,26 +922,53 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -906,26 +922,53 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# The encoder output is already processed and stored # The encoder output is already processed and stored
# in the decoder's KV cache. # in the decoder's KV cache.
continue continue
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens,
)
assert start_idx < end_idx
mm_hash = mm_feature.identifier mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None) encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\ assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}." f"Encoder cache miss for {mm_hash}."
assert pos_info.is_embed is None, "Expected all positions to"\ assert pos_info.is_embed is None, "Expected all positions to"\
" be contiguous and embeddings." " be contiguous and embeddings."
encoder_output = self.encoder_cache[mm_hash]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
= True
# Only whole mm items are processed
mm_embeds.append(encoder_output) mm_embeds.append(encoder_output)
return mm_embeds
def _get_model_inputs(self, input_ids: torch.Tensor, req_start_idx += num_scheduled_tokens
mm_embeds: list[torch.Tensor]):
is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens] \
.to(self.device)
return mm_embeds, is_mm_embed
def _get_model_inputs(
self,
input_ids: torch.Tensor,
mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]],
):
if self.supports_mm_inputs: if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
# NOTE(woosuk): To unify token ids and soft tokens (vision # NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids) # embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text. # as input to the multimodal model, even when the input is text.
inputs_embeds = self.model.get_input_embeddings( inputs_embeds = self.model.get_input_embeddings(
input_ids=input_ids, input_ids,
multimodal_embeddings=mm_embeds, multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
) )
return None, inputs_embeds return None, inputs_embeds
else: else:
# For text-only models, we use token ids as input. # For text-only models, we use token ids as input.
...@@ -953,9 +996,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -953,9 +996,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.supports_mm_inputs: if self.supports_mm_inputs:
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output) self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output) mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
else: else:
mm_embeds = [] mm_embed_inputs = None
torch_xla.sync(wait=False) torch_xla.sync(wait=False)
# Prepare inputs, the requests might be split into multiple # Prepare inputs, the requests might be split into multiple
# executions, combine the result of each execution. # executions, combine the result of each execution.
...@@ -972,7 +1016,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -972,7 +1016,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata, logits_indices, padded_num_reqs, num_reqs,\ attn_metadata, logits_indices, padded_num_reqs, num_reqs,\
end_index = self._prepare_inputs(scheduler_output, start_index) end_index = self._prepare_inputs(scheduler_output, start_index)
input_ids, inputs_embeds = self._get_model_inputs( input_ids, inputs_embeds = self._get_model_inputs(
self.input_ids, mm_embeds) self.input_ids, mm_embed_inputs)
torch_xla.sync(wait=False) torch_xla.sync(wait=False)
# Run the decoder # Run the decoder
with set_forward_context( with set_forward_context(
...@@ -1325,9 +1369,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1325,9 +1369,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hf_config.image_token_index hf_config.image_token_index
placeholders_ids = placeholders_ids.to(self.device) placeholders_ids = placeholders_ids.to(self.device)
mm_mask = torch.tensor([False] * num_tokens)
mm_mask[:items_size] = True
mm_mask = mm_mask.to(self.device)
# Assign outputs or the graph will be cut short. # Assign outputs or the graph will be cut short.
a, b = self._get_model_inputs(placeholders_ids, a, b = self._get_model_inputs(
[mm_embeds]) placeholders_ids,
mm_embed_inputs=([mm_embeds], mm_mask),
)
assert a is None assert a is None
torch_xla.sync(wait=False) torch_xla.sync(wait=False)
...@@ -1338,7 +1388,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1338,7 +1388,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int32, dtype=torch.int32,
device="cpu") device="cpu")
placeholders_ids = placeholders_ids.to(self.device) placeholders_ids = placeholders_ids.to(self.device)
a, b = self._get_model_inputs(placeholders_ids, []) a, b = self._get_model_inputs(
placeholders_ids,
mm_embed_inputs=None,
)
assert a is None assert a is None
torch_xla.sync(wait=False) torch_xla.sync(wait=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment