Unverified Commit 50c746ee authored by Gunjan Chhablani, committed by GitHub

Allow only textual inputs to VisualBert (#13687)

parent 93624bfe
@@ -778,30 +778,31 @@ class VisualBertModel(VisualBertPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
-        if visual_embeds is None:
-            raise ValueError(
-                f"`visual_embeds` can not be of type {type(visual_embeds)} when using a VisualBert Model."
-            )
-
         batch_size, seq_length = input_shape
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
-        visual_input_shape = visual_embeds.size()[:-1]
+        if visual_embeds is not None:
+            visual_input_shape = visual_embeds.size()[:-1]
 
         if attention_mask is None:
             attention_mask = torch.ones(input_shape, device=device)
 
-        if visual_attention_mask is None:
+        if visual_embeds is not None and visual_attention_mask is None:
             visual_attention_mask = torch.ones(visual_input_shape, device=device)
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
-            combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
-        )
+        if visual_embeds is not None:
+            combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+                combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
+            )
+        else:
+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+                attention_mask, [batch_size, input_shape], device
+            )
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
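For context, a minimal sketch of what the hunk above enables: VisualBertModel can now run on textual inputs alone, building the extended attention mask from attention_mask only instead of raising a ValueError when visual_embeds is missing. The checkpoint and tokenizer names below are illustrative, not part of this commit; any VisualBert checkpoint would do.

    import torch
    from transformers import BertTokenizer, VisualBertModel

    # Illustrative names: the VisualBert docs pair a BERT tokenizer with a VisualBert checkpoint.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre")

    inputs = tokenizer("What is the man eating?", return_tensors="pt")

    with torch.no_grad():
        # No visual_embeds / visual_attention_mask passed: before this commit the forward
        # pass raised a ValueError; after it, the new `else` branch builds the extended
        # attention mask from the textual attention_mask alone.
        outputs = model(**inputs)

    print(outputs.last_hidden_state.shape)  # (batch_size, text_seq_len, hidden_size)

When visual_embeds and visual_attention_mask are supplied, the behaviour is unchanged: both masks are concatenated and the combined mask covers the text and visual positions.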