Unverified Commit 7f730085 authored by Yih-Dar, committed by GitHub

Handle image_embeds in ViltModel (#16696)



* update

* batch_size -> text_batch_size
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 161c0a2e
@@ -704,7 +704,7 @@ VILT_IMAGES_AND_TEXT_CLASSIFICATION_INPUTS_DOCSTRING = r"""
             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
             model's internal embedding lookup matrix.
-        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_images, num_patches, hidden_size)`, *optional*):
             Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
             This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
@@ -805,18 +805,22 @@ class ViltModel(ViltPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
-        batch_size, seq_length = input_shape
+        text_batch_size, seq_length = input_shape
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
         if attention_mask is None:
-            attention_mask = torch.ones(((batch_size, seq_length)), device=device)
+            attention_mask = torch.ones(((text_batch_size, seq_length)), device=device)
 
-        if pixel_values is None:
-            raise ValueError("You have to specify pixel_values")
+        if pixel_values is not None and image_embeds is not None:
+            raise ValueError("You cannot specify both pixel_values and image_embeds at the same time")
+        elif pixel_values is None and image_embeds is None:
+            raise ValueError("You have to specify either pixel_values or image_embeds")
 
-        batch_size, num_channels, height, width = pixel_values.shape
+        image_batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeds.shape[0]
+        if image_batch_size != text_batch_size:
+            raise ValueError("The text inputs and image inputs need to have the same batch size")
+
         if pixel_mask is None:
-            pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+            pixel_mask = torch.ones((image_batch_size, self.config.image_size, self.config.image_size), device=device)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
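These checks run before any heavy computation, so they can be exercised with a randomly initialized model. A minimal sketch, assuming a default `ViltConfig`; the token ids and the patch count of 144 are placeholders, not values from this diff:

```python
import torch
from transformers import ViltConfig, ViltModel

config = ViltConfig()
model = ViltModel(config)  # randomly initialized; no download needed

input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])  # text_batch_size = 1
pixel_values = torch.randn(1, 3, config.image_size, config.image_size)
image_embeds = torch.randn(1, 144, config.hidden_size)  # 144 patches is a placeholder

# Passing both image inputs is now rejected explicitly.
try:
    model(input_ids=input_ids, pixel_values=pixel_values, image_embeds=image_embeds)
except ValueError as e:
    print(e)  # You cannot specify both pixel_values and image_embeds at the same time

# Passing neither is also rejected.
try:
    model(input_ids=input_ids)
except ValueError as e:
    print(e)  # You have to specify either pixel_values or image_embeds

# Mismatched batch sizes are caught as well (2 images vs. 1 text sequence).
try:
    model(input_ids=input_ids, pixel_values=pixel_values.repeat(2, 1, 1, 1))
except ValueError as e:
    print(e)  # The text inputs and image inputs need to have the same batch size
```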
@@ -1338,11 +1342,17 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        if pixel_values.ndim == 4:
+        if pixel_values is not None and pixel_values.ndim == 4:
             # add dummy num_images dimension
             pixel_values = pixel_values.unsqueeze(1)
 
+        if image_embeds is not None and image_embeds.ndim == 3:
+            # add dummy num_images dimension
+            image_embeds = image_embeds.unsqueeze(1)
+
-        num_images = pixel_values.shape[1]
+        num_images = pixel_values.shape[1] if pixel_values is not None else None
+        if num_images is None:
+            num_images = image_embeds.shape[1] if image_embeds is not None else None
         if num_images != self.config.num_images:
             raise ValueError(
                 "Make sure to match the number of images in the model with the number of images in the input."
@@ -1356,11 +1366,11 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
                 input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids,
-                pixel_values=pixel_values[:, i, :, :, :],
+                pixel_values=pixel_values[:, i, :, :, :] if pixel_values is not None else None,
                 pixel_mask=pixel_mask[:, i, :, :] if pixel_mask is not None else None,
                 head_mask=head_mask,
                 inputs_embeds=inputs_embeds,
-                image_embeds=image_embeds,
+                image_embeds=image_embeds[:, i, :, :] if image_embeds is not None else None,
                 image_token_type_idx=i + 1,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
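Inside the loop, each iteration hands one image slice (or `None`) to the shared `ViltModel` trunk, so a 4-D `image_embeds` collapses back to the 3-D shape `ViltModel` expects. A tensor-only sketch with hypothetical sizes:

```python
import torch

num_images = 2
image_embeds = torch.randn(4, num_images, 144, 768)  # hypothetical sizes
pixel_values = None

for i in range(num_images):
    per_image_pixels = pixel_values[:, i, :, :, :] if pixel_values is not None else None
    per_image_embeds = image_embeds[:, i, :, :] if image_embeds is not None else None
    print(i, per_image_embeds.shape)  # (4, 144, 768), the 3-D shape ViltModel expects
```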