Unverified commit 28d00482, authored by Michael Benayoun, committed by GitHub

Fx support for multiple model architectures (#17393)

* Support for Bart and LayoutLM, and partial support for XLNet

* Support for mbart

* A lot of new models supported

* Support for other models

* LayoutLM fix

* Use strings instead of classes
parent 04681c1d
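For context, the practical effect of this commit is that the architectures listed below can be traced with torch.fx via transformers.utils.fx.symbolic_trace. A minimal usage sketch follows; the checkpoint name, the chosen input names, and the dummy tensors are illustrative assumptions, not part of this diff:

```python
# Minimal sketch (assumed setup): tracing one of the newly supported architectures.
# "facebook/bart-base" and the input_names below are illustrative choices.
import torch
from transformers import BartForConditionalGeneration
from transformers.utils.fx import symbolic_trace

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
traced = symbolic_trace(
    model,
    input_names=["input_ids", "attention_mask", "decoder_input_ids"],
)

# The traced module is a torch.fx.GraphModule and can be called like the original model.
input_ids = torch.zeros((1, 8), dtype=torch.long)
_ = traced(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    decoder_input_ids=input_ids,
)
print(type(traced))
```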
@@ -93,7 +93,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -114,7 +114,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class BartLearnedPositionalEmbedding(nn.Embedding):
@@ -911,7 +911,7 @@ class BartDecoder(BartPretrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
@@ -2112,7 +2112,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
@@ -83,7 +83,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -105,7 +105,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
@@ -850,7 +850,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.blenderbot.modeling_blenderbot.BlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall
@@ -846,7 +846,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
@@ -57,7 +57,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# contrastive loss function, adapted from
@@ -674,7 +674,7 @@ class CLIPTextTransformer(nn.Module):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len)
- mask.fill_(float("-inf"))
+ mask.fill_(torch.tensor(float("-inf")))
mask.triu_(1)  # zero out the lower diagonal
mask = mask.unsqueeze(1)  # expand mask
return mask
@@ -1042,8 +1042,8 @@ class CLIPModel(CLIPPreTrainedModel):
text_embeds = self.text_projection(text_embeds)
# normalized features
- image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
- text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
...
@@ -800,7 +800,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
if bbox is None:
- bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
+ bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device)
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
...
@@ -79,7 +79,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -101,7 +101,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
@@ -998,7 +998,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
@@ -81,7 +81,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -103,7 +103,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class MarianSinusoidalPositionalEmbedding(nn.Embedding):
@@ -856,7 +856,7 @@ class MarianDecoder(MarianPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
@@ -97,7 +97,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -119,7 +119,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
@@ -909,7 +909,7 @@ class MBartDecoder(MBartPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
@@ -61,7 +61,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -82,7 +82,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class OPTLearnedPositionalEmbedding(nn.Embedding):
@@ -513,7 +513,7 @@ class OPTDecoder(OPTPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
@@ -876,7 +876,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
@@ -94,7 +94,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -116,7 +116,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart
@@ -883,7 +883,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
@@ -69,7 +69,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -91,7 +91,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class Conv1dSubsampler(nn.Module):
@@ -888,7 +888,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
@@ -49,7 +49,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -71,7 +71,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->Speech2Text2
@@ -495,7 +495,7 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
@@ -50,7 +50,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -72,7 +72,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR
@@ -524,7 +524,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
@@ -120,7 +120,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -142,7 +142,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
@@ -577,7 +577,7 @@ class XGLMModel(XGLMPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -712,7 +712,7 @@ class XGLMModel(XGLMPreTrainedModel):
hidden_states = inputs_embeds + positions
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
...
@@ -1056,7 +1056,6 @@ class XLNetModel(XLNetPreTrainedModel):
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
- pos_emb = pos_emb.to(self.device)
return pos_emb
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1206,6 +1205,7 @@ class XLNetModel(XLNetPreTrainedModel):
# Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
+ pos_emb = pos_emb.to(output_h.device)
pos_emb = self.dropout(pos_emb)
# Prepare head mask if needed
...
@@ -29,27 +29,23 @@ from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx.proxy import ParameterProxy
- from .. import (
+ from .. import PretrainedConfig, PreTrainedModel, logging
- CONFIG_MAPPING,
- MODEL_FOR_CAUSAL_LM_MAPPING,
- MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
- MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
- MODEL_FOR_MASKED_LM_MAPPING,
- MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
- MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
- MODEL_FOR_PRETRAINING_MAPPING,
- MODEL_FOR_QUESTION_ANSWERING_MAPPING,
- MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
- MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
- MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
- MODEL_MAPPING,
- GPT2DoubleHeadsModel,
- PretrainedConfig,
- PreTrainedModel,
- XLNetForQuestionAnswering,
- logging,
- )
from ..models.auto import get_values
+ from ..models.auto.modeling_auto import (
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
+ MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
+ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
+ MODEL_FOR_PRETRAINING_MAPPING_NAMES,
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_MAPPING_NAMES,
+ )
from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
from ..utils.versions import importlib_metadata
@@ -57,25 +53,25 @@ from ..utils.versions import importlib_metadata
logger = logging.get_logger(__name__)
- def _generate_supported_model_classes(
+ def _generate_supported_model_class_names(
model_name: Type[PretrainedConfig],
supported_tasks: Optional[Union[str, List[str]]] = None,
- ) -> List[Type[PreTrainedModel]]:
+ ) -> List[str]:
- model_config_class = CONFIG_MAPPING[model_name]
task_mapping = {
- "default": MODEL_MAPPING,
+ "default": MODEL_MAPPING_NAMES,
- "pretraining": MODEL_FOR_PRETRAINING_MAPPING,
+ "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
- "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
+ "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
- "masked-lm": MODEL_FOR_MASKED_LM_MAPPING,
+ "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
- "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING,
+ "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
- "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+ "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
- "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+ "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
- "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+ "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
- "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+ "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
- "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+ "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
- "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
+ "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
- "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+ "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
+ "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
}
if supported_tasks is None:
@@ -83,55 +79,78 @@ def _generate_supported_model_classes(
if isinstance(supported_tasks, str):
supported_tasks = [supported_tasks]
- model_classes = []
+ model_class_names = []
for task in supported_tasks:
- model_class = task_mapping[task].get(model_config_class, None)
+ class_name = task_mapping[task].get(model_name, None)
- if model_class:
+ if class_name:
- model_classes.append(model_class)
+ model_class_names.append(class_name)
- return model_classes
+ return model_class_names
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"albert",
+ "bart",
"bert",
+ "blenderbot",
+ "blenderbot-small",
+ "clip",
"distilbert",
- "mobilebert",
"electra",
- "megatron-bert",
"gpt2",
- "gptj",
"gpt_neo",
- "t5",
+ "gptj",
+ "layoutlm",
+ "m2m_100",
+ "marian",
+ "mbart",
+ "megatron-bert",
+ "mobilebert",
+ "mt5",
+ "opt",
+ "pegasus",
+ "plbart",
"roberta",
- "vit",
+ "speech_to_text",
+ "speech_to_text_2",
"swin",
+ "t5",
+ "trocr",
+ "vit",
+ "xglm",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
- # "layoutlm",
# "xlnet",
]
_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
if isinstance(item, dict):
- _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(**item))
+ _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
else:
- _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(item))
+ _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))
_SPECIAL_SUPPORTED_MODELS = [
- GPT2DoubleHeadsModel,
+ "CLIPTextModel",
+ "CLIPVisionModel",
+ "GPT2DoubleHeadsModel",
+ "Speech2Text2Decoder",
+ "TrOCRDecoder",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
# XLNetForQuestionAnswering,
]
- _SUPPORTED_MODELS = tuple(
- sorted(list(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)), key=lambda c: c.__name__)
- )
+ _SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
def torch_nn_embedding(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
+ def torch_nn_functional_embedding(
+ input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
+ ):
+ return torch.empty(*input.shape, weight.shape[-1], device="meta")
def torch_nn_layernorm(self, input):
return input
@@ -176,6 +195,12 @@ def torch_arange(*args, **kwargs):
start, end = args
else:
start, end, step = args
+ if isinstance(start, float):
+ start = int(start)
+ if isinstance(end, float):
+ end = int(end)
+ if isinstance(step, float):
+ step = int(step)
step = kwargs.get("step", step)
dtype = kwargs.get("dtype")
return torch.empty((end - start) // step, dtype=dtype, device="meta")
@@ -265,6 +290,14 @@ def torch_matmul(input, other, *, out=None):
return torch.empty(*shape, device="meta")
+ def torch_bmm(input, mat2, *, out=None):
+ if out is not None:
+ raise ValueError("Don't support in-place bmm for MetaTensor analysis")
+ batch_size, n, m = input.shape
+ _, _, p = mat2.shape
+ return torch.empty(batch_size, n, p, device="meta")
def torch_einsum(equation, *operands):
# TODO: infer shape without performing the computation, this might be quite hard.
concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
@@ -285,13 +318,39 @@ def torch_index_select(input, dim, index, *, out=None):
def torch_tensor_index_select(self, dim, index):
- return torch_tensor_index_select(self, dim, index)
+ return torch_index_select(self, dim, index)
def torch_roll(input, shifts, dims=None):
return input
+ def torch_flip(input, dims):
+ return input
+ def torch_tensor_flip(self, dims):
+ return self
+ def torch_nn_conv1d(self, input):
+ l_in = input.shape[-1]
+ shape = None
+ padding = self.padding
+ if padding == "valid":
+ padding = (0, 0)
+ if padding == "same":
+ shape = list(input.shape)
+ if shape is None:
+ shape = list(input.shape)
+ l_out = math.floor(
+ (l_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ shape[-1] = l_out
+ shape[-2] = self.out_channels
+ return torch.empty(shape, device="meta")
def torch_nn_conv2d(self, input):
h_in, w_in = input.shape[-2:]
shape = None
@@ -325,6 +384,21 @@ def torch_tensor_unsqueeze(self, dim):
return torch_unsqueeze(self, dim)
+ def torch_unique_consecutive(input, **kwargs):
+ output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
+ if isinstance(output, torch.Tensor):
+ return output.to("meta")
+ else:
+ return tuple(map(lambda x: x.to("meta"), output))
+ def torch_nn_functional_one_hot(tensor, num_classes=-1):
+ if num_classes < 0:
+ raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
+ shape = list(tensor.shape) + [num_classes]
+ return torch.empty(shape, device="meta")
def torch_nn_mseloss(self, input, target):
if self.reduction == "none":
shape = target.shape
@@ -350,14 +424,27 @@ def torch_nn_bcewithlogitsloss(self, input, target):
def operator_getitem(a, b):
+ def to_concrete(t):
+ if isinstance(t, torch.Tensor):
+ concrete = torch.ones_like(t, device="cpu")
+ if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
+ concrete = concrete.to(torch.int64)
+ return concrete
+ return t
if isinstance(a, torch.Tensor):
# TODO: infer shape without performing the computation.
+ if isinstance(b, tuple):
+ b = tuple(map(to_concrete, b))
+ else:
+ b = to_concrete(b)
return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
return operator.getitem(a, b)
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
torch.nn.Embedding: torch_nn_embedding,
+ torch.nn.functional.embedding: torch_nn_functional_embedding,
torch.nn.LayerNorm: torch_nn_layernorm,
torch.nn.Linear: torch_nn_linear,
torch.relu: torch_relu,
@@ -372,15 +459,20 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
torch.mul: torch_mul,
torch.Tensor.mul: torch_tensor_mul,
torch.matmul: torch_matmul,
+ torch.bmm: torch_bmm,
torch.einsum: torch_einsum,
torch.Tensor.repeat: torch_tensor_repeat,
torch.roll: torch_roll,
- # TODO: those might not be needed.
+ torch.flip: torch_flip,
- # torch.index_select: torch_index_select,
+ torch.Tensor.flip: torch_tensor_flip,
- # torch.Tensor.index_select: torch_tensor_index_select,
+ torch.index_select: torch_index_select,
+ torch.Tensor.index_select: torch_tensor_index_select,
+ torch.nn.Conv1d: torch_nn_conv1d,
torch.nn.Conv2d: torch_nn_conv2d,
torch.unsqueeze: torch_unsqueeze,
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
+ torch.unique_consecutive: torch_unique_consecutive,
+ torch.nn.functional.one_hot: torch_nn_functional_one_hot,
torch.nn.MSELoss: torch_nn_mseloss,
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
@@ -513,7 +605,7 @@ class HFTracer(Tracer):
# Feature flag for proxying accesses to buffer values
proxy_buffer_attributes: bool = True
allow_insert_stateless_mods: bool = True
- _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
+ _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty"]
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
@@ -532,22 +624,22 @@ class HFTracer(Tracer):
"""Generates dummy input for model inference recording."""
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
# from pickle, or from the "__class__" attribute in the general case.
- model_class = getattr(model, "class_for_deserialization", model.__class__)
+ model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
device = model.device
inputs_dict = {}
if input_name in ["labels", "start_positions", "end_positions"]:
batch_size = shape[0]
- if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
+ if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
- elif model_class in [
+ elif model_class_name in [
- *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING),
+ *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
- XLNetForQuestionAnswering,
+ "XLNetForQuestionAnswering",
]:
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
- elif model_class in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
+ elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
raise ValueError(
"Could not retrieve the problem type for the sequence classification task, please set "
@@ -571,32 +663,49 @@ class HFTracer(Tracer):
)
inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
- elif model_class in [
+ elif model_class_name in [
- *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
+ *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
- *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
+ *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
]:
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
- elif model_class in [
+ elif model_class_name in [
- *get_values(MODEL_FOR_PRETRAINING_MAPPING),
+ *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
- *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
+ *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
- *get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
+ *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
- *get_values(MODEL_FOR_MASKED_LM_MAPPING),
+ *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
- *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
+ *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
- GPT2DoubleHeadsModel,
+ "GPT2DoubleHeadsModel",
]:
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
else:
- raise NotImplementedError(f"{model_class} not supported yet.")
+ raise NotImplementedError(f"{model_class_name} not supported yet.")
elif "pixel_values" in input_name:
batch_size = shape[0]
- image_size = model.config.image_size
+ image_size = getattr(model.config, "image_size", None)
+ if image_size is None:
+ if hasattr(model.config, "vision_config"):
+ image_size = model.config.vision_config.image_size
+ elif hasattr(model.config, "encoder"):
+ image_size = model.config.encoder.image_size
+ else:
+ raise AttributeError('Could not find the "image_size" field in the model config')
+ # If no num_channels is in the config, use some arbitrary value.
+ num_channels = getattr(model.config, "num_channels", 3)
if not isinstance(image_size, collections.abc.Iterable):
image_size = (image_size, image_size)
height, width = image_size
inputs_dict[input_name] = torch.zeros(
- batch_size, model.config.num_channels, height, width, dtype=torch.float32, device=device
+ batch_size, num_channels, height, width, dtype=torch.float32, device=device
)
+ elif "bbox" in input_name:
+ inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
+ elif "input_features" in input_name:
+ inputs_dict[input_name] = torch.zeros(
+ *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
+ )
+ elif "inputs" in input_name:
+ inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
elif "mask" in input_name or "ids" in input_name:
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
else:
@@ -628,6 +737,8 @@ class HFTracer(Tracer):
if kind == "call_function": if kind == "call_function":
meta_target = _MANUAL_META_OVERRIDES.get(target, target) meta_target = _MANUAL_META_OVERRIDES.get(target, target)
meta_out = meta_target(*args_metas, **kwargs_metas) meta_out = meta_target(*args_metas, **kwargs_metas)
if isinstance(meta_out, torch.Tensor):
meta_out = meta_out.to(device="meta")
elif kind == "call_method": elif kind == "call_method":
method = getattr(args_metas[0].__class__, target) method = getattr(args_metas[0].__class__, target)
meta_target = _MANUAL_META_OVERRIDES.get(method, method) meta_target = _MANUAL_META_OVERRIDES.get(method, method)
...@@ -731,7 +842,7 @@ class HFTracer(Tracer): ...@@ -731,7 +842,7 @@ class HFTracer(Tracer):
sequence_length = _generate_random_int() sequence_length = _generate_random_int()
shape = [batch_size, sequence_length] shape = [batch_size, sequence_length]
if root.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
num_choices = _generate_random_int(low=2, high=5) num_choices = _generate_random_int(low=2, high=5)
shape.insert(1, num_choices) shape.insert(1, num_choices)
...@@ -870,11 +981,22 @@ def symbolic_trace( ...@@ -870,11 +981,22 @@ def symbolic_trace(
if input_names is None:
input_names = model.dummy_inputs.keys()
+ input_names = list(input_names)
sig = inspect.signature(model.forward)
+ if not (set(input_names) <= set(sig.parameters.keys())):
+ formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
+ formatted_allowed_input_names = ", ".join(sig.parameters.keys())
+ raise ValueError(
+ f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
+ f" {formatted_allowed_input_names}"
+ )
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
- if not isinstance(model, _SUPPORTED_MODELS):
+ if model.__class__.__name__ not in _SUPPORTED_MODELS:
- supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS))
+ supported_model_names = ", ".join(_SUPPORTED_MODELS)
raise NotImplementedError(
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
)
...
@@ -413,6 +413,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
)
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
@@ -1386,6 +1387,7 @@ class BartStandaloneDecoderModelTester:
class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else ()
+ fx_compatible = True
test_pruning = False
is_encoder_decoder = False
...
@@ -218,6 +218,7 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
all_model_classes = (BlenderbotModel, BlenderbotForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
...