Unverified commit 28d00482, authored by Michael Benayoun and committed by GitHub

Fx support for multiple model architectures (#17393)

* Support for Bart and LayoutLM, and partial support for XLNet

* Support for mbart

* A lot of new models supported

* Support for other models

* LayoutLM fix

* Use strings instead of classes
parent 04681c1d
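
As a rough sketch of what this commit enables (the checkpoint name and input names below are illustrative, not taken from the commit itself), tracing one of the newly covered architectures might look like:

    from transformers import BartForConditionalGeneration
    from transformers.utils.fx import symbolic_trace

    # Load a seq2seq model that is now on the supported list.
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    # Symbolic tracing records shapes on the "meta" device, so no real forward
    # pass over actual data is executed while building the torch.fx graph.
    traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "decoder_input_ids"])
    print(traced.graph)
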
......@@ -93,7 +93,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
-mask = torch.full((tgt_len, tgt_len), float("-inf"))
+mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
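
This change, repeated across the decoder models below, swaps the Python float fill value for a zero-dimensional tensor; presumably this keeps torch.full well behaved once it is intercepted during tracing (it is added to _TORCH_METHODS_TO_PATCH further down in this commit). A minimal eager-mode check that the mask itself is unchanged:

    import torch

    tgt_len, dtype = 4, torch.float32
    a = torch.full((tgt_len, tgt_len), float("-inf"))
    b = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
    assert torch.equal(a, b)  # same additive mask either way
    mask_cond = torch.arange(tgt_len)
    b.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
    print(b.to(dtype))  # -inf strictly above the diagonal, zeros on and below it
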
......@@ -114,7 +114,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class BartLearnedPositionalEmbedding(nn.Embedding):
......@@ -911,7 +911,7 @@ class BartDecoder(BartPretrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -2112,7 +2112,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -83,7 +83,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
-mask = torch.full((tgt_len, tgt_len), float("-inf"))
+mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......@@ -105,7 +105,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
......@@ -850,7 +850,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
-mask = torch.full((tgt_len, tgt_len), float("-inf"))
+mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......@@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.blenderbot.modeling_blenderbot.BlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall
......@@ -846,7 +846,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -57,7 +57,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# contrastive loss function, adapted from
......@@ -674,7 +674,7 @@ class CLIPTextTransformer(nn.Module):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len)
-mask.fill_(float("-inf"))
+mask.fill_(torch.tensor(float("-inf")))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
......@@ -1042,8 +1042,8 @@ class CLIPModel(CLIPPreTrainedModel):
text_embeds = self.text_projection(text_embeds)
# normalized features
-image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
-text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
+image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
......
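
Tensor.norm defaults to the 2-norm, so spelling out p=2 in the CLIP projection above is behaviour-preserving; the explicit argument simply avoids the p=None dispatch path, which is presumably what made the original call awkward to trace. A quick eager-mode check:

    import torch

    x = torch.randn(2, 8)
    assert torch.allclose(x.norm(dim=-1), x.norm(p=2, dim=-1))
    normalized = x / x.norm(p=2, dim=-1, keepdim=True)
    print(normalized.norm(p=2, dim=-1))  # each row has unit length
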
......@@ -800,7 +800,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
if bbox is None:
-bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
+bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device)
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
......
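
The LayoutLM change works because input_shape is a torch.Size, which is a tuple subclass: concatenating (4,) appends the bbox coordinate dimension without the list round-trip used before. For example:

    import torch

    input_shape = torch.Size([2, 16])
    bbox = torch.zeros(input_shape + (4,), dtype=torch.long)
    print(bbox.shape)  # torch.Size([2, 16, 4])
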
......@@ -79,7 +79,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
-mask = torch.full((tgt_len, tgt_len), float("-inf"))
+mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......@@ -101,7 +101,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
......@@ -998,7 +998,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -81,7 +81,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
-mask = torch.full((tgt_len, tgt_len), float("-inf"))
+mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......@@ -103,7 +103,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class MarianSinusoidalPositionalEmbedding(nn.Embedding):
......@@ -856,7 +856,7 @@ class MarianDecoder(MarianPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -97,7 +97,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
-mask = torch.full((tgt_len, tgt_len), float("-inf"))
+mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......@@ -119,7 +119,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
......@@ -909,7 +909,7 @@ class MBartDecoder(MBartPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -61,7 +61,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
-mask = torch.full((tgt_len, tgt_len), float("-inf"))
+mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......@@ -82,7 +82,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class OPTLearnedPositionalEmbedding(nn.Embedding):
......@@ -513,7 +513,7 @@ class OPTDecoder(OPTPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
-mask = torch.full((tgt_len, tgt_len), float("-inf"))
+mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......@@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
......@@ -876,7 +876,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -94,7 +94,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
-mask = torch.full((tgt_len, tgt_len), float("-inf"))
+mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......@@ -116,7 +116,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart
......@@ -883,7 +883,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -69,7 +69,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
-mask = torch.full((tgt_len, tgt_len), float("-inf"))
+mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......@@ -91,7 +91,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class Conv1dSubsampler(nn.Module):
......@@ -888,7 +888,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -49,7 +49,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
-mask = torch.full((tgt_len, tgt_len), float("-inf"))
+mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......@@ -71,7 +71,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->Speech2Text2
......@@ -495,7 +495,7 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -50,7 +50,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
-mask = torch.full((tgt_len, tgt_len), float("-inf"))
+mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......@@ -72,7 +72,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR
......@@ -524,7 +524,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
......@@ -120,7 +120,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
-mask = torch.full((tgt_len, tgt_len), float("-inf"))
+mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
......@@ -142,7 +142,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
-return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
......@@ -577,7 +577,7 @@ class XGLMModel(XGLMPreTrainedModel):
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
+).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......@@ -712,7 +712,7 @@ class XGLMModel(XGLMPreTrainedModel):
hidden_states = inputs_embeds + positions
-hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
......
......@@ -1056,7 +1056,6 @@ class XLNetModel(XLNetPreTrainedModel):
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
-pos_emb = pos_emb.to(self.device)
return pos_emb
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......@@ -1206,6 +1205,7 @@ class XLNetModel(XLNetPreTrainedModel):
# Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
+pos_emb = pos_emb.to(output_h.device)
pos_emb = self.dropout(pos_emb)
# Prepare head mask if needed
......
......@@ -29,27 +29,23 @@ from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx.proxy import ParameterProxy
-from .. import (
-CONFIG_MAPPING,
-MODEL_FOR_CAUSAL_LM_MAPPING,
-MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
-MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
-MODEL_FOR_MASKED_LM_MAPPING,
-MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
-MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
-MODEL_FOR_PRETRAINING_MAPPING,
-MODEL_FOR_QUESTION_ANSWERING_MAPPING,
-MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
-MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
-MODEL_MAPPING,
-GPT2DoubleHeadsModel,
-PretrainedConfig,
-PreTrainedModel,
-XLNetForQuestionAnswering,
-logging,
-)
+from .. import PretrainedConfig, PreTrainedModel, logging
from ..models.auto import get_values
+from ..models.auto.modeling_auto import (
+MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
+MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
+MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
+MODEL_FOR_PRETRAINING_MAPPING_NAMES,
+MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
+MODEL_MAPPING_NAMES,
+)
from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
from ..utils.versions import importlib_metadata
......@@ -57,25 +53,25 @@ from ..utils.versions import importlib_metadata
logger = logging.get_logger(__name__)
-def _generate_supported_model_classes(
+def _generate_supported_model_class_names(
model_name: Type[PretrainedConfig],
supported_tasks: Optional[Union[str, List[str]]] = None,
-) -> List[Type[PreTrainedModel]]:
+) -> List[str]:
-model_config_class = CONFIG_MAPPING[model_name]
task_mapping = {
"default": MODEL_MAPPING,
"pretraining": MODEL_FOR_PRETRAINING_MAPPING,
"next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
"masked-lm": MODEL_FOR_MASKED_LM_MAPPING,
"causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING,
"seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
"multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
"sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
"default": MODEL_MAPPING_NAMES,
"pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
"next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
"masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
"causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
"seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
"speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
"multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
"sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
}
if supported_tasks is None:
......@@ -83,55 +79,78 @@ def _generate_supported_model_classes(
if isinstance(supported_tasks, str):
supported_tasks = [supported_tasks]
-model_classes = []
+model_class_names = []
for task in supported_tasks:
-model_class = task_mapping[task].get(model_config_class, None)
-if model_class:
-model_classes.append(model_class)
+class_name = task_mapping[task].get(model_name, None)
+if class_name:
+model_class_names.append(class_name)
-return model_classes
+return model_class_names
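
The *_MAPPING_NAMES objects from modeling_auto are ordered dicts keyed by model-type strings and valued with class names, so the lookup above needs neither the config class nor the model class itself. A small illustration:

    from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

    # Model-type string in, class-name string out.
    print(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get("gpt2"))  # GPT2LMHeadModel
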
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"albert",
"bart",
"bert",
"blenderbot",
"blenderbot-small",
"clip",
"distilbert",
"mobilebert",
"electra",
"megatron-bert",
"gpt2",
"gptj",
"gpt_neo",
"t5",
"gptj",
"layoutlm",
"m2m_100",
"marian",
"mbart",
"megatron-bert",
"mobilebert",
"mt5",
"opt",
"pegasus",
"plbart",
"roberta",
"vit",
"speech_to_text",
"speech_to_text_2",
"swin",
"t5",
"trocr",
"vit",
"xglm",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
# "layoutlm",
# "xlnet",
]
_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
if isinstance(item, dict):
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(**item))
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
else:
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(item))
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))
_SPECIAL_SUPPORTED_MODELS = [
-GPT2DoubleHeadsModel,
+"CLIPTextModel",
+"CLIPVisionModel",
+"GPT2DoubleHeadsModel",
+"Speech2Text2Decoder",
+"TrOCRDecoder",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
# XLNetForQuestionAnswering,
]
-_SUPPORTED_MODELS = tuple(
-sorted(list(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)), key=lambda c: c.__name__)
-)
+_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
def torch_nn_embedding(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
def torch_nn_functional_embedding(
input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
):
return torch.empty(*input.shape, weight.shape[-1], device="meta")
def torch_nn_layernorm(self, input):
return input
......@@ -176,6 +195,12 @@ def torch_arange(*args, **kwargs):
start, end = args
else:
start, end, step = args
if isinstance(start, float):
start = int(start)
if isinstance(end, float):
end = int(end)
if isinstance(step, float):
step = int(step)
step = kwargs.get("step", step)
dtype = kwargs.get("dtype")
return torch.empty((end - start) // step, dtype=dtype, device="meta")
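
The new float-to-int casts above feed the shape computation of the arange override; only the element count matters here, as in this small check (exact agreement with torch.arange holds when end - start is a multiple of step):

    import torch

    start, end, step = 0, 10, 2
    assert torch.arange(start, end, step).shape[0] == (end - start) // step
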
......@@ -265,6 +290,14 @@ def torch_matmul(input, other, *, out=None):
return torch.empty(*shape, device="meta")
def torch_bmm(input, mat2, *, out=None):
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
batch_size, n, m = input.shape
_, _, p = mat2.shape
return torch.empty(batch_size, n, p, device="meta")
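
torch_bmm above is a pure shape rule: it reads (batch, n, m) and (batch, m, p) and allocates a (batch, n, p) meta tensor without performing the batched matmul. Roughly:

    import torch

    a = torch.empty(8, 4, 16, device="meta")
    b = torch.empty(8, 16, 32, device="meta")
    out = torch.empty(a.shape[0], a.shape[1], b.shape[2], device="meta")
    print(out.shape)  # torch.Size([8, 4, 32])
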
def torch_einsum(equation, *operands):
# TODO: infer shape without performing the computation, this might be quite hard.
concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
......@@ -285,13 +318,39 @@ def torch_index_select(input, dim, index, *, out=None):
def torch_tensor_index_select(self, dim, index):
-return torch_tensor_index_select(self, dim, index)
+return torch_index_select(self, dim, index)
def torch_roll(input, shifts, dims=None):
return input
def torch_flip(input, dims):
return input
def torch_tensor_flip(self, dims):
return self
def torch_nn_conv1d(self, input):
l_in = input.shape[-1]
shape = None
padding = self.padding
if padding == "valid":
padding = (0, 0)
if padding == "same":
shape = list(input.shape)
if shape is None:
shape = list(input.shape)
l_out = math.floor(
(l_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
)
shape[-1] = l_out
shape[-2] = self.out_channels
return torch.empty(shape, device="meta")
def torch_nn_conv2d(self, input):
h_in, w_in = input.shape[-2:]
shape = None
......@@ -325,6 +384,21 @@ def torch_tensor_unsqueeze(self, dim):
return torch_unsqueeze(self, dim)
def torch_unique_consecutive(input, **kwargs):
output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
if isinstance(output, torch.Tensor):
return output.to("meta")
else:
return tuple(map(lambda x: x.to("meta"), output))
def torch_nn_functional_one_hot(tensor, num_classes=-1):
if num_classes < 0:
raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
shape = list(tensor.shape) + [num_classes]
return torch.empty(shape, device="meta")
def torch_nn_mseloss(self, input, target):
if self.reduction == "none":
shape = target.shape
......@@ -350,14 +424,27 @@ def torch_nn_bcewithlogitsloss(self, input, target):
def operator_getitem(a, b):
def to_concrete(t):
if isinstance(t, torch.Tensor):
concrete = torch.ones_like(t, device="cpu")
if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
concrete = concrete.to(torch.int64)
return concrete
return t
if isinstance(a, torch.Tensor):
# TODO: infer shape without performing the computation.
if isinstance(b, tuple):
b = tuple(map(to_concrete, b))
else:
b = to_concrete(b)
return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
return operator.getitem(a, b)
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
torch.nn.Embedding: torch_nn_embedding,
torch.nn.functional.embedding: torch_nn_functional_embedding,
torch.nn.LayerNorm: torch_nn_layernorm,
torch.nn.Linear: torch_nn_linear,
torch.relu: torch_relu,
......@@ -372,15 +459,20 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
torch.mul: torch_mul,
torch.Tensor.mul: torch_tensor_mul,
torch.matmul: torch_matmul,
torch.bmm: torch_bmm,
torch.einsum: torch_einsum,
torch.Tensor.repeat: torch_tensor_repeat,
torch.roll: torch_roll,
# TODO: those might not be needed.
# torch.index_select: torch_index_select,
# torch.Tensor.index_select: torch_tensor_index_select,
torch.flip: torch_flip,
torch.Tensor.flip: torch_tensor_flip,
torch.index_select: torch_index_select,
torch.Tensor.index_select: torch_tensor_index_select,
torch.nn.Conv1d: torch_nn_conv1d,
torch.nn.Conv2d: torch_nn_conv2d,
torch.unsqueeze: torch_unsqueeze,
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
torch.unique_consecutive: torch_unique_consecutive,
torch.nn.functional.one_hot: torch_nn_functional_one_hot,
torch.nn.MSELoss: torch_nn_mseloss,
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
......@@ -513,7 +605,7 @@ class HFTracer(Tracer):
# Feature flag for proxying accesses to buffer values
proxy_buffer_attributes: bool = True
allow_insert_stateless_mods: bool = True
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty"]
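
Adding "full" and "empty" to this list means those factory calls are, roughly speaking, also rerouted during tracing so that the tensors they create only carry shape and dtype, much like building them directly on the meta device:

    import torch

    placeholder = torch.empty((2, 3), device="meta")
    print(placeholder.shape, placeholder.device)  # torch.Size([2, 3]) meta
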
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
......@@ -532,22 +624,22 @@ class HFTracer(Tracer):
"""Generates dummy input for model inference recording."""
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
# from pickle, or from the "__class__" attribute in the general case.
-model_class = getattr(model, "class_for_deserialization", model.__class__)
+model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
device = model.device
inputs_dict = {}
if input_name in ["labels", "start_positions", "end_positions"]:
batch_size = shape[0]
-if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
+if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
-elif model_class in [
-*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING),
-XLNetForQuestionAnswering,
+elif model_class_name in [
+*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
+"XLNetForQuestionAnswering",
]:
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
-elif model_class in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
+elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
raise ValueError(
"Could not retrieve the problem type for the sequence classification task, please set "
......@@ -571,32 +663,49 @@ class HFTracer(Tracer):
)
inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
-elif model_class in [
-*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
-*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
+elif model_class_name in [
+*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
+*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
]:
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
-elif model_class in [
-*get_values(MODEL_FOR_PRETRAINING_MAPPING),
-*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
-*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
-*get_values(MODEL_FOR_MASKED_LM_MAPPING),
-*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
-GPT2DoubleHeadsModel,
+elif model_class_name in [
+*get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
+*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
+*get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
+*get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
+*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
+"GPT2DoubleHeadsModel",
]:
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
else:
raise NotImplementedError(f"{model_class} not supported yet.")
raise NotImplementedError(f"{model_class_name} not supported yet.")
elif "pixel_values" in input_name:
batch_size = shape[0]
-image_size = model.config.image_size
+image_size = getattr(model.config, "image_size", None)
+if image_size is None:
+if hasattr(model.config, "vision_config"):
+image_size = model.config.vision_config.image_size
+elif hasattr(model.config, "encoder"):
+image_size = model.config.encoder.image_size
+else:
+raise AttributeError('Could not find the "image_size" field in the model config')
+# If no num_channels is in the config, use some arbitrary value.
+num_channels = getattr(model.config, "num_channels", 3)
if not isinstance(image_size, collections.abc.Iterable):
image_size = (image_size, image_size)
height, width = image_size
inputs_dict[input_name] = torch.zeros(
-batch_size, model.config.num_channels, height, width, dtype=torch.float32, device=device
+batch_size, num_channels, height, width, dtype=torch.float32, device=device
)
elif "bbox" in input_name:
inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
elif "input_features" in input_name:
inputs_dict[input_name] = torch.zeros(
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
)
elif "inputs" in input_name:
inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
elif "mask" in input_name or "ids" in input_name:
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
else:
......@@ -628,6 +737,8 @@ class HFTracer(Tracer):
if kind == "call_function":
meta_target = _MANUAL_META_OVERRIDES.get(target, target)
meta_out = meta_target(*args_metas, **kwargs_metas)
if isinstance(meta_out, torch.Tensor):
meta_out = meta_out.to(device="meta")
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
meta_target = _MANUAL_META_OVERRIDES.get(method, method)
......@@ -731,7 +842,7 @@ class HFTracer(Tracer):
sequence_length = _generate_random_int()
shape = [batch_size, sequence_length]
-if root.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
+if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
num_choices = _generate_random_int(low=2, high=5)
shape.insert(1, num_choices)
......@@ -870,11 +981,22 @@ def symbolic_trace(
if input_names is None:
input_names = model.dummy_inputs.keys()
input_names = list(input_names)
sig = inspect.signature(model.forward)
if not (set(input_names) <= set(sig.parameters.keys())):
formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
formatted_allowed_input_names = ", ".join(sig.parameters.keys())
raise ValueError(
f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
f" {formatted_allowed_input_names}"
)
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
-if not isinstance(model, _SUPPORTED_MODELS):
-supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS))
+if model.__class__.__name__ not in _SUPPORTED_MODELS:
+supported_model_names = ", ".join(_SUPPORTED_MODELS)
raise NotImplementedError(
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
)
......
......@@ -413,6 +413,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
)
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
fx_compatible = True
test_pruning = False
test_missing_keys = False
......@@ -1386,6 +1387,7 @@ class BartStandaloneDecoderModelTester:
class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else ()
fx_compatible = True
test_pruning = False
is_encoder_decoder = False
......
......@@ -218,6 +218,7 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
all_model_classes = (BlenderbotModel, BlenderbotForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
fx_compatible = True
test_pruning = False
test_missing_keys = False
......