Unverified Commit 27d7638b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
Co-authored-by: default avatarNickLucche <nlucches@redhat.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
parent 17617398
......@@ -1365,19 +1365,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
......
......@@ -49,8 +49,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
# # === Audio Inputs === #
......@@ -438,19 +437,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
masked_audio_features = self._process_audio_input(audio_input)
return masked_audio_features
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.audio_token_index)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
......@@ -467,8 +453,11 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.audio_token_index,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,
......
......@@ -1459,19 +1459,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
return inputs_embeds
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
......
......@@ -79,7 +79,8 @@ from .qwen2_5_vl import (Qwen2_5_VisionAttention,
from .qwen2_vl import Qwen2VLProcessingInfo
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
maybe_prefix, merge_multimodal_embeddings)
_merge_multimodal_embeddings, maybe_prefix,
merge_multimodal_embeddings)
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
logger = init_logger(__name__)
......@@ -1324,17 +1325,22 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return multimodal_embeddings
def _compute_deepstack_embeds(
self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings) -> torch.Tensor:
visual_lens = [
x.shape[0] if isinstance(x, torch.Tensor) else len(x)
for x in multimodal_embeddings
]
self,
inputs_embeds: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings,
is_multimodal: torch.Tensor,
) -> tuple[torch.Tensor, MultiModalEmbeddings]:
visual_lens = [len(x) for x in multimodal_embeddings]
multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split( # noqa:E501
multimodal_embeddings_cat, [self.visual_dim, self.multiscale_dim],
dim=-1)
(
multimodal_embeddings_main,
multimodal_embeddings_multiscale,
) = torch.split(
multimodal_embeddings_cat,
[self.visual_dim, self.multiscale_dim],
dim=-1,
)
multimodal_embeddings = torch.split(multimodal_embeddings_main,
visual_lens,
......@@ -1346,39 +1352,62 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds.size(0),
self.deepstack_num_level * inputs_embeds.size(1))
deepstack_input_embeds = merge_multimodal_embeddings(
input_ids,
deepstack_input_embeds,
multimodal_embeddings_multiscale,
placeholder_token_id=[
self.config.image_token_id, self.config.video_token_id
],
deepstack_input_embeds = _merge_multimodal_embeddings(
inputs_embeds=deepstack_input_embeds,
multimodal_embeddings=multimodal_embeddings_multiscale,
is_multimodal=is_multimodal,
)
deepstack_input_embeds = deepstack_input_embeds.view(
inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim)
deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)
return deepstack_input_embeds, multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
deepstack_input_embeds = None
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if self.use_deepstack:
deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds( # noqa:E501
input_ids, inputs_embeds, multimodal_embeddings)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
inputs_embeds = self._get_text_embeddings(
input_ids,
self.language_model.get_input_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")
if self.use_deepstack:
if deepstack_input_embeds is None:
deepstack_input_embeds = torch.zeros_like(
inputs_embeds).unsqueeze(0).repeat(
self.deepstack_num_level, 1, 1).contiguous()
(
deepstack_input_embeds,
multimodal_embeddings,
) = self._compute_deepstack_embeds(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
else:
deepstack_input_embeds = None
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
if deepstack_input_embeds is not None:
deepstack_input_embeds = torch.zeros_like(inputs_embeds).unsqueeze(
0).repeat(self.deepstack_num_level, 1, 1).contiguous()
self._set_deepstack_input_embeds(deepstack_input_embeds)
return inputs_embeds
......
......@@ -45,7 +45,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .qwen import QWenBaseModel, QWenModel
from .utils import flatten_bn, merge_multimodal_embeddings
from .utils import flatten_bn
class QwenImagePixelInputs(TensorSchema):
......@@ -756,21 +756,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.transformer.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.transformer.visual.image_pad_id)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
......@@ -786,8 +771,12 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids ==
self.transformer.visual.image_pad_id,
)
input_ids = None
hidden_states = self.transformer(input_ids, positions,
......
......@@ -218,6 +218,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.roberta.get_input_embeddings(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
......
......@@ -38,7 +38,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
IMG_START = '<img>'
IMG_END = '</img>'
......@@ -842,19 +842,24 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert self.img_context_token_id is not None
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.img_context_token_id,
)
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,
......@@ -873,8 +878,11 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.img_context_token_id,
)
input_ids = None
forward_kwargs = {
......
......@@ -483,6 +483,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -395,6 +395,9 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
......
......@@ -25,7 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors)
MultiModalKwargsItems)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
......@@ -37,8 +37,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
from .vision import run_dp_sharded_vision_model
......@@ -996,10 +995,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
1 else cur_feature[0])
return merged_image_features
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
return []
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
......@@ -1007,24 +1009,21 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
if multimodal_embeddings is None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
else:
is_text = input_ids != self.config.image_token_id
text_ids = input_ids[is_text]
text_embeds = self.language_model.model.get_input_embeddings(
text_ids)
inputs_embeds = torch.empty(input_ids.shape[0],
text_embeds.shape[-1],
dtype=text_embeds.dtype,
device=text_embeds.device)
inputs_embeds[is_text] = text_embeds
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_id)
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,
......@@ -1038,10 +1037,11 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None
hidden_states = self.language_model(input_ids,
......
......@@ -40,7 +40,7 @@ from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
from .vision import VisionEncoderInfo, get_vision_encoder_info
......@@ -589,22 +589,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
return []
return self._process_image_input(image_input)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
......@@ -617,8 +601,11 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(
input_ids=input_ids,
......
......@@ -233,6 +233,9 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# We do not really use any input tokens and therefore no embeddings
# to be calculated. However, due to the mandatory token ids in
......
......@@ -52,8 +52,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsQuant)
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
flatten_bn, make_empty_intermediate_tensors_factory,
maybe_prefix)
......@@ -797,6 +797,9 @@ class TransformersForCausalLM(TransformersBase):
else:
self.lm_head = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings()(input_ids)
def compute_logits(
self,
hidden_states: torch.Tensor,
......@@ -873,13 +876,19 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
multimodal_embeds = self.get_multimodal_embeddings(**kwargs)
if multimodal_embeds is not None:
inputs_embeds = self.get_input_embeddings(
input_ids, multimodal_embeds)
input_ids,
multimodal_embeds,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None
model_output = super().forward(input_ids, positions,
intermediate_tensors, inputs_embeds)
return model_output
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(self, **kwargs):
pixel_values = kwargs.pop("pixel_values", None)
pixel_values = pixel_values if pixel_values is not None else kwargs.pop(
......@@ -934,15 +943,42 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings=None,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0):
mask = (input_ids == self.config.image_token_id)
mask = mask.unsqueeze(-1).expand_as(inputs_embeds)
multimodal_embeddings = torch.cat(multimodal_embeddings)
inputs_embeds = inputs_embeds.masked_scatter(
mask, multimodal_embeddings)
return inputs_embeds
"""
Apply token embeddings to `input_ids`.
If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=False`
to avoid calling the language model's `get_input_embeddings` method
on those tokens.
"""
from .utils import _merge_multimodal_embeddings
inputs_embeds = self._get_text_embeddings(
input_ids,
self.model.get_input_embeddings(),
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
......@@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
_MAX_ENCODER_BATCH_SIZE = 16
......@@ -555,19 +554,21 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# The audio token index is not included in the embedding table
# We need to remove it before embedding lookup
safe_input_ids = input_ids.clone()
safe_input_ids[safe_input_ids == self.config.audio_token_index] = 0
inputs_embeds = self.language_model.get_input_embeddings(
safe_input_ids)
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.audio_token_index)
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(self,
input_ids: torch.Tensor,
......@@ -601,8 +602,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.audio_token_index,
)
input_ids = None
language_model = self.language_model
......
......@@ -4,7 +4,7 @@
import itertools
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, Optional, Protocol, Union, overload
from typing import Any, Literal, Optional, Protocol, Union, overload
import torch
import torch.nn as nn
......@@ -391,8 +391,8 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor,
multimodal_embeddings: NestedTensors,
is_multimodal: torch.Tensor,
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
......@@ -402,61 +402,35 @@ def _merge_multimodal_embeddings(
Note:
This updates ``inputs_embeds`` in place.
"""
flattened = _flatten_embeddings(multimodal_embeddings)
if len(multimodal_embeddings) == 0:
return inputs_embeds
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
input_dtype = inputs_embeds.dtype
try:
# This is equivalent to: inputs_embeds[is_multimodal] = flattened.
# For debugging
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
flattened.to(dtype=inputs_embeds.dtype))
mm_embeds_flat.to(dtype=input_dtype))
except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
if flattened.shape[0] != num_expected_tokens:
if num_actual_tokens != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError(
f"Attempted to assign {expr} = {flattened.shape[0]} "
f"Attempted to assign {expr} = {num_actual_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders"
) from e
else:
raise ValueError("Error during masked scatter operation") from e
return inputs_embeds
def embed_multimodal(
input_ids: torch.Tensor,
multimodal_token_id: int,
get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
multimodal_embeds: NestedTensors,
) -> torch.Tensor:
"""
Embed token IDs and multimodal inputs and combine their embeddings.
``multimodal_token_id`` is used to determine whether a token ID should
be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.
Compared to ``merge_multimodal_embeddings`, this avoids running
``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
which causes issues when the placeholder token ID exceeds the
vocabulary size of the language model.
"""
is_multimodal = input_ids == multimodal_token_id
is_text = ~is_multimodal
text_embeds = get_text_embeds(input_ids[is_text])
merged_embeds = torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
)
raise ValueError("Error during masked scatter operation") from e
merged_embeds[is_text] = text_embeds
return _merge_multimodal_embeddings(
merged_embeds,
is_multimodal,
multimodal_embeds,
)
return inputs_embeds
def merge_multimodal_embeddings(
......@@ -491,23 +465,29 @@ def merge_multimodal_embeddings(
This updates ``inputs_embeds`` in place.
"""
if isinstance(placeholder_token_id, list):
placeholder_token_id = torch.tensor(
placeholder_token_id,
pin_memory=is_pin_memory_available()).to(device=input_ids.device,
non_blocking=True)
return _merge_multimodal_embeddings(
inputs_embeds,
torch.isin(input_ids, placeholder_token_id),
multimodal_embeddings,
)
is_multimodal = isin_list(input_ids, placeholder_token_id)
else:
is_multimodal = (input_ids == placeholder_token_id)
return _merge_multimodal_embeddings(
inputs_embeds,
(input_ids == placeholder_token_id),
multimodal_embeddings,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
def isin_list(
elements: torch.Tensor,
test_elements_list: list[int],
) -> torch.Tensor:
test_elements = torch.tensor(
test_elements_list,
pin_memory=is_pin_memory_available(),
).to(device=elements.device, non_blocking=True)
return torch.isin(elements, test_elements)
class LayerFn(Protocol):
def __call__(self, prefix: str) -> torch.nn.Module:
......
......@@ -45,10 +45,8 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsTranscription)
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
logger = init_logger(__name__)
......@@ -376,9 +374,14 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
audio_encoder = self.tokenizer.instruct.audio_encoder
audio_tok_id = audio_encoder.audio_token
audio_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
audio_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
audio_embeddings,
is_multimodal=input_ids == audio_tok_id,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,
......@@ -421,20 +424,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return audio_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
audio_encoder = self.tokenizer.instruct.audio_encoder
audio_tok_id = audio_encoder.audio_token
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, audio_tok_id)
return inputs_embeds
def _parse_and_validate_audio_arrays(
self, **kwargs: object) -> Union[list[torch.Tensor], None]:
audio_arrays = kwargs.pop("audio_arrays", None)
......
......@@ -579,10 +579,7 @@ class WhisperDecoder(nn.Module):
hidden_states = self.layer_norm(hidden_states)
return hidden_states
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
......@@ -916,7 +913,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens.
......
......@@ -18,6 +18,7 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
......@@ -64,8 +65,10 @@ class EagleProposer:
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.is_multimodal_model = vllm_config.model_config \
.is_multimodal_model
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
vllm_config.model_config)
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
......@@ -175,7 +178,8 @@ class EagleProposer:
last_token_indices: Optional[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embeds: Optional[list[torch.Tensor]] = None,
mm_embed_inputs: Optional[tuple[list[torch.Tensor],
torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
......@@ -219,18 +223,21 @@ class EagleProposer:
# copy inputs to buffer for cudagraph
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states
if self.is_multimodal_model:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = self.model.get_input_embeddings(
input_ids,
multimodal_embeddings=mm_embeds or None,
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
self.inputs_embeds[:num_tokens] = self.model.get_input_embeddings(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
self.inputs_embeds[:num_tokens] = inputs_embeds
inputs_embeds = self.inputs_embeds[:num_input_tokens]
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
......@@ -372,14 +379,15 @@ class EagleProposer:
self.input_ids[:batch_size] = input_ids
self._set_positions(batch_size, clamped_positions)
self.hidden_states[:batch_size] = hidden_states
if self.is_multimodal_model:
inputs_embeds = self.model.get_input_embeddings(input_ids)
self.inputs_embeds[:batch_size] = inputs_embeds
inputs_embeds = self.inputs_embeds[:input_batch_size]
if self.supports_mm_inputs:
self.inputs_embeds[:batch_size] = \
self.model.get_input_embeddings(input_ids)
input_ids = None
inputs_embeds = self.inputs_embeds[:input_batch_size]
else:
inputs_embeds = None
input_ids = self.input_ids[:input_batch_size]
inputs_embeds = None
# Run the model.
with set_forward_context(per_layer_attn_metadata,
......@@ -849,7 +857,7 @@ class EagleProposer:
self.attn_layer_names = list(draft_attn_layer_names)
if self.is_multimodal_model:
if self.supports_mm_inputs:
# Even if the target model is multimodal, we can also use
# text-only draft models
try:
......@@ -861,7 +869,7 @@ class EagleProposer:
logger.warning(
"Draft model does not support multimodal inputs, "
"falling back to text-only mode")
self.is_multimodal_model = False
self.supports_mm_inputs = False
if supports_multimodal(target_model):
# handle multimodality
......@@ -933,7 +941,7 @@ class EagleProposer:
) -> None:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
if self.is_multimodal_model:
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
......
......@@ -368,6 +368,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int64)
# Only relevant for multimodal models
if self.supports_mm_inputs:
self.is_mm_embed = self._make_buffer(self.max_num_tokens,
dtype=torch.bool)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy
......@@ -1627,9 +1632,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
shift_computed_tokens: int = 0,
) -> list[torch.Tensor]:
) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
mm_embeds = list[torch.Tensor]()
is_mm_embed = self.is_mm_embed.cpu
is_mm_embed[:total_num_scheduled_tokens] = False
req_start_idx = 0
should_sync_mrope_positions = False
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = []
......@@ -1638,6 +1650,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
num_computed_tokens = \
req_state.num_computed_tokens + shift_computed_tokens
for mm_feature in req_state.mm_features:
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
......@@ -1670,6 +1683,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
= True if is_embed is None else is_embed
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
......@@ -1677,6 +1694,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_embeds_req.append(mm_embeds_item)
if self.is_multimodal_pruning_enabled and self.uses_mrope:
assert req_state.mrope_positions is not None
should_sync_mrope_positions = True
mm_embeds_req, new_mrope_positions, new_delta = (
self.model.recompute_mrope_positions(
......@@ -1685,18 +1703,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mrope_positions=req_state.mrope_positions,
num_computed_tokens=req_state.num_computed_tokens,
))
assert req_state.mrope_positions is not None
req_state.mrope_positions.copy_(new_mrope_positions)
req_state.mrope_position_delta = new_delta
mm_embeds.extend(mm_embeds_req)
req_start_idx += num_scheduled_tokens
is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
if should_sync_mrope_positions:
self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(
scheduler_output.total_num_scheduled_tokens)
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
return mm_embeds
return mm_embeds, is_mm_embed
def _extract_encoder_inputs(
self,
......@@ -1990,14 +2009,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
and not self.model_config.is_encoder_decoder):
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
mm_embeds, is_mm_embed = self._gather_mm_embeddings(
scheduler_output)
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
inputs_embeds_scheduled = self.model.get_input_embeddings(
input_ids=self.input_ids.gpu[:num_scheduled_tokens],
multimodal_embeddings=mm_embeds or None,
self.input_ids.gpu[:num_scheduled_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
# TODO(woosuk): Avoid the copy. Optimize.
......@@ -2586,10 +2607,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
mm_embeds = None
if self.supports_mm_inputs:
mm_embeds = self._gather_mm_embeddings(scheduler_output,
shift_computed_tokens=1)
mm_embed_inputs = self._gather_mm_embeddings(
scheduler_output,
shift_computed_tokens=1,
)
else:
mm_embed_inputs = None
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
......@@ -2599,8 +2624,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
last_token_indices=token_indices_to_sample,
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
mm_embeds=mm_embeds,
mm_embed_inputs=mm_embed_inputs,
)
return draft_token_ids
def update_config(self, overrides: dict[str, Any]) -> None:
......
......@@ -263,6 +263,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
# Only relevant for multimodal models
if self.supports_mm_inputs:
self.is_mm_embed_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory)
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
# Keep in int64 to avoid overflow with long context
......@@ -879,13 +886,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
mm_embeds: list[torch.Tensor] = []
) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
padded_total_num_scheduled_tokens = _get_padded_token_len(
self.num_tokens_paddings, total_num_scheduled_tokens)
is_mm_embed = self.is_mm_embed_cpu
is_mm_embed[:padded_total_num_scheduled_tokens] = False
mm_embeds = list[torch.Tensor]()
req_start_idx = 0
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
# TODO unroll loop and assume/enforce --disable_chunked_mm_input
# NOTE (NickLucche) here we diverge from logic in other runners, as
# we assume to only have whole mm items to process. Hence we avoid
......@@ -906,26 +922,53 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens,
)
assert start_idx < end_idx
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."
assert pos_info.is_embed is None, "Expected all positions to"\
" be contiguous and embeddings."
encoder_output = self.encoder_cache[mm_hash]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
= True
# Only whole mm items are processed
mm_embeds.append(encoder_output)
return mm_embeds
def _get_model_inputs(self, input_ids: torch.Tensor,
mm_embeds: list[torch.Tensor]):
req_start_idx += num_scheduled_tokens
is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens] \
.to(self.device)
return mm_embeds, is_mm_embed
def _get_model_inputs(
self,
input_ids: torch.Tensor,
mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]],
):
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
inputs_embeds = self.model.get_input_embeddings(
input_ids=input_ids,
input_ids,
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
return None, inputs_embeds
else:
# For text-only models, we use token ids as input.
......@@ -953,9 +996,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
else:
mm_embeds = []
mm_embed_inputs = None
torch_xla.sync(wait=False)
# Prepare inputs, the requests might be split into multiple
# executions, combine the result of each execution.
......@@ -972,7 +1016,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata, logits_indices, padded_num_reqs, num_reqs,\
end_index = self._prepare_inputs(scheduler_output, start_index)
input_ids, inputs_embeds = self._get_model_inputs(
self.input_ids, mm_embeds)
self.input_ids, mm_embed_inputs)
torch_xla.sync(wait=False)
# Run the decoder
with set_forward_context(
......@@ -1325,9 +1369,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hf_config.image_token_index
placeholders_ids = placeholders_ids.to(self.device)
mm_mask = torch.tensor([False] * num_tokens)
mm_mask[:items_size] = True
mm_mask = mm_mask.to(self.device)
# Assign outputs or the graph will be cut short.
a, b = self._get_model_inputs(placeholders_ids,
[mm_embeds])
a, b = self._get_model_inputs(
placeholders_ids,
mm_embed_inputs=([mm_embeds], mm_mask),
)
assert a is None
torch_xla.sync(wait=False)
......@@ -1338,7 +1388,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int32,
device="cpu")
placeholders_ids = placeholders_ids.to(self.device)
a, b = self._get_model_inputs(placeholders_ids, [])
a, b = self._get_model_inputs(
placeholders_ids,
mm_embed_inputs=None,
)
assert a is None
torch_xla.sync(wait=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment