Unverified Commit 50c746ee authored by Gunjan Chhablani, committed by GitHub

Allow only textual inputs to VisualBert (#13687)

parent 93624bfe
@@ -778,29 +778,30 @@ class VisualBertModel(VisualBertPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if visual_embeds is None:
-            raise ValueError(
-                f"`visual_embeds` can not be of type {type(visual_embeds)} when using a VisualBert Model."
-            )
-
         batch_size, seq_length = input_shape
         device = input_ids.device if input_ids is not None else inputs_embeds.device

-        visual_input_shape = visual_embeds.size()[:-1]
+        if visual_embeds is not None:
+            visual_input_shape = visual_embeds.size()[:-1]

         if attention_mask is None:
             attention_mask = torch.ones(input_shape, device=device)

-        if visual_attention_mask is None:
+        if visual_embeds is not None and visual_attention_mask is None:
             visual_attention_mask = torch.ones(visual_input_shape, device=device)

         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
-            combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
-        )
+        if visual_embeds is not None:
+            combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+                combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
+            )
+        else:
+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+                attention_mask, [batch_size, input_shape], device
+            )

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
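
With the `visual_embeds is None` check removed, the model can now be driven by text alone: when no visual embeddings are passed, the new else-branch builds the extended attention mask from the textual attention mask only. A minimal usage sketch under that assumption follows; the checkpoint and tokenizer names are illustrative (taken from the VisualBert documentation) and are not part of this commit.

    from transformers import BertTokenizer, VisualBertModel

    # Illustrative checkpoint; any VisualBert checkpoint behaves the same here.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre")

    inputs = tokenizer("What is the man eating?", return_tensors="pt")

    # visual_embeds is simply omitted: after this change the forward pass
    # no longer raises, and the mask is built from the text inputs alone.
    outputs = model(**inputs)

    last_hidden_state = outputs.last_hidden_state  # (batch_size, seq_length, hidden_size)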