Unverified Commit 9c83b96e authored by Patrick von Platen, committed by GitHub

[Tests] Add Common Test for Training + Fix a couple of bugs (#8415)

* add training tests

* correct longformer

* fix docs

* fix some tests

* fix some more train tests

* remove ipdb

* fix multiple edge case model training

* fix funnel and prophetnet

* clean gpt models

* undo renaming of albert
parent 52040517
......@@ -81,6 +81,13 @@ AutoModelForMultipleChoice
:members:
AutoModelForNextSentencePrediction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AutoModelForNextSentencePrediction
:members:
AutoModelForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -1801,7 +1801,7 @@ class GeneralizedRCNN(nn.Module):
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
......
......@@ -29,7 +29,7 @@ For further information or requests, please go to [BERTimbau repository](https:/
```python
from transformers import AutoTokenizer # Or BertTokenizer
from transformers import AutoModelForPretraining # Or BertForPreTraining for loading pretraining heads
from transformers import AutoModelForPreTraining # Or BertForPreTraining for loading pretraining heads
from transformers import AutoModel # or BertModel, for BERT without pretraining heads
model = AutoModelForPreTraining.from_pretrained('neuralmind/bert-base-portuguese-cased')
......
......@@ -329,6 +329,7 @@ if is_torch_available():
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
......@@ -340,6 +341,7 @@ if is_torch_available():
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForNextSentencePrediction,
AutoModelForPreTraining,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
......
......@@ -77,6 +77,7 @@ from .modeling_bart import (
from .modeling_bert import (
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
......@@ -128,6 +129,7 @@ from .modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
from .modeling_funnel import (
FunnelForMaskedLM,
FunnelForMultipleChoice,
FunnelForPreTraining,
FunnelForQuestionAnswering,
FunnelForSequenceClassification,
FunnelForTokenClassification,
......@@ -143,12 +145,13 @@ from .modeling_longformer import (
LongformerForTokenClassification,
LongformerModel,
)
from .modeling_lxmert import LxmertForPreTraining, LxmertModel
from .modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
from .modeling_marian import MarianMTModel
from .modeling_mbart import MBartForConditionalGeneration
from .modeling_mobilebert import (
MobileBertForMaskedLM,
MobileBertForMultipleChoice,
MobileBertForNextSentencePrediction,
MobileBertForPreTraining,
MobileBertForQuestionAnswering,
MobileBertForSequenceClassification,
......@@ -166,6 +169,7 @@ from .modeling_rag import ( # noqa: F401 - need to import all RagModels to be i
from .modeling_reformer import (
ReformerForMaskedLM,
ReformerForQuestionAnswering,
ReformerForSequenceClassification,
ReformerModel,
ReformerModelWithLMHead,
)
......@@ -285,6 +289,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(CTRLConfig, CTRLLMHeadModel),
(ElectraConfig, ElectraForPreTraining),
(LxmertConfig, LxmertForPreTraining),
(FunnelConfig, FunnelForPreTraining),
]
)
......@@ -396,6 +401,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(DebertaConfig, DebertaForSequenceClassification),
(GPT2Config, GPT2ForSequenceClassification),
(OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
(ReformerConfig, ReformerForSequenceClassification),
]
)
......@@ -417,6 +423,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
(ElectraConfig, ElectraForQuestionAnswering),
(ReformerConfig, ReformerForQuestionAnswering),
(FunnelConfig, FunnelForQuestionAnswering),
(LxmertConfig, LxmertForQuestionAnswering),
]
)
......@@ -460,6 +467,13 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
]
)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
[
(BertConfig, BertForNextSentencePrediction),
(MobileBertConfig, MobileBertForNextSentencePrediction),
]
)
AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
......@@ -1519,3 +1533,103 @@ class AutoModelForMultipleChoice:
", ".join(c.__name__ for c in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
)
)
class AutoModelForNextSentencePrediction:
r"""
This is a generic model class that will be instantiated as one of the model classes of the library---with a
next sentence prediction head---when created with the
:meth:`~transformers.AutoModelForNextSentencePrediction.from_pretrained` class method or the
:meth:`~transformers.AutoModelForNextSentencePrediction.from_config` class method.
This class cannot be instantiated directly using ``__init__()`` (throws an error).
"""
def __init__(self):
raise EnvironmentError(
"AutoModelForNextSentencePrediction is designed to be instantiated "
"using the `AutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or "
"`AutoModelForNextSentencePrediction.from_config(config)` methods."
)
@classmethod
@replace_list_option_in_docstrings(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False)
def from_config(cls, config):
r"""
Instantiates one of the model classes of the library---with a next sentence prediction head---from a
configuration.
Note:
Loading a model from its configuration file does **not** load the model weights. It only affects the
model's configuration. Use :meth:`~transformers.AutoModelForNextSentencePrediction.from_pretrained` to load
the model weights.
Args:
config (:class:`~transformers.PretrainedConfig`):
The model class to instantiate is selected based on the configuration class:
List options
Examples::
>>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
>>> # Download configuration from S3 and cache.
>>> config = AutoConfig.from_pretrained('bert-base-uncased')
>>> model = AutoModelForNextSentencePrediction.from_config(config)
"""
if type(config) in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
return MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config)
raise ValueError(
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
"Model type should be one of {}.".format(
config.__class__,
cls.__name__,
", ".join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
)
)
@classmethod
@replace_list_option_in_docstrings(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING)
@add_start_docstrings(
"Instantiate one of the model classes of the library---with a multiple choice classification head---from a "
"pretrained model.",
AUTO_MODEL_PRETRAINED_DOCSTRING,
)
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
Examples::
>>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
>>> # Download model and configuration from S3 and cache.
>>> model = AutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')
>>> # Update configuration during loading
>>> model = AutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
>>> model = AutoModelForNextSentencePrediction.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
if type(config) in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
return MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError(
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
"Model type should be one of {}.".format(
config.__class__,
cls.__name__,
", ".join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
)
)
......@@ -1228,13 +1228,14 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
position_ids=None,
head_mask=None,
inputs_embeds=None,
next_sentence_label=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs
):
r"""
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the next sentence prediction (classification) loss. Input should be a sequence pair
(see ``input_ids`` docstring). Indices should be in ``[0, 1]``:
......@@ -1255,10 +1256,18 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
>>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1]))
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
>>> logits = outputs.logits
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
"""
if "next_sentence_label" in kwargs:
warnings.warn(
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
......@@ -1278,9 +1287,9 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
seq_relationship_scores = self.cls(pooled_output)
next_sentence_loss = None
if next_sentence_label is not None:
if labels is not None:
loss_fct = CrossEntropyLoss()
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), next_sentence_label.view(-1))
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
if not return_dict:
output = (seq_relationship_scores,) + outputs[2:]
......
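For reference, a minimal usage sketch of the renamed argument (it assumes the public `bert-base-uncased` checkpoint can be downloaded; it is not part of this diff): the new `labels` keyword drives the NSP loss, while the old `next_sentence_label` keyword is only remapped to `labels` after a `FutureWarning`.

```python
# Hedged sketch, not part of this commit: shows the `labels` rename from the hunk above.
import warnings

import torch
from transformers import BertForNextSentencePrediction, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
encoding = tokenizer("The sky is blue.", "Pizza is tasty.", return_tensors="pt")

# New-style call: the loss comes from `labels`.
outputs = model(**encoding, labels=torch.LongTensor([1]), return_dict=True)
print(outputs.loss, outputs.logits.shape)

# Deprecated call: still accepted, but emits a FutureWarning and is forwarded to `labels`.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model(**encoding, next_sentence_label=torch.LongTensor([1]), return_dict=True)
assert any(issubclass(w.category, FutureWarning) for w in caught)
```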
......@@ -1069,7 +1069,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(pooled_logits.view(-1), labels.view(-1))
loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
......
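For reference, a minimal standalone sketch of why the regression branch now casts the labels (the tensors below are illustrative stand-ins, not the GPT-2 head itself): `MSELoss` expects targets in the same floating dtype as the logits, so the integer sequence labels used when `num_labels == 1` have to be converted before the loss call.

```python
# Hedged sketch, not part of this commit: reproduces the dtype issue the cast fixes.
import torch
from torch.nn import MSELoss

pooled_logits = torch.randn(4, 1, requires_grad=True)  # stand-in for the head's output
labels = torch.tensor([0, 1, 1, 0])                    # integer-typed regression targets

loss_fct = MSELoss()
# loss_fct(pooled_logits.view(-1), labels.view(-1))    # fails: Long targets vs Float logits
loss = loss_fct(pooled_logits.view(-1), labels.to(pooled_logits.dtype).view(-1))
loss.backward()  # gradients flow once the targets are floats
```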
......@@ -1069,7 +1069,7 @@ class LongformerEncoder(nn.Module):
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return module(*inputs, is_global_attn)
return custom_forward
......@@ -1079,7 +1079,6 @@ class LongformerEncoder(nn.Module):
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
)
else:
layer_outputs = layer_module(
......
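For reference, a generic sketch of the checkpointing pattern this hunk restores (the layer below is a stand-in, not the real `LongformerLayer`): the non-tensor `is_global_attn` flag is captured inside the `create_custom_forward` closure, and only tensors are passed positionally to `torch.utils.checkpoint.checkpoint`.

```python
# Hedged sketch, not part of this commit: boolean flags go through the closure,
# tensors go through `checkpoint`, so the recomputed forward sees the same flag.
import torch
from torch.utils.checkpoint import checkpoint


class DummyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, hidden_states, is_global_attn):
        # the flag only selects a code path; it is never a tensor input
        return self.linear(hidden_states) if is_global_attn else hidden_states


def create_custom_forward(module):
    def custom_forward(*inputs):
        # tensors arrive from `checkpoint`; the flag comes from the enclosing scope
        return module(*inputs, is_global_attn)

    return custom_forward


layer_module = DummyLayer()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)
is_global_attn = True

out = checkpoint(create_custom_forward(layer_module), hidden_states)
out.sum().backward()
```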
......@@ -17,6 +17,7 @@
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
......@@ -1154,16 +1155,17 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
visual_attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
masked_lm_labels=None,
labels=None,
obj_labels=None,
matched_label=None,
ans=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
r"""
masked_lm_labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
(masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
......@@ -1183,6 +1185,15 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
Returns:
"""
if "masked_lm_labels" in kwargs:
warnings.warn(
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("masked_lm_labels")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
device = input_ids.device if input_ids is not None else inputs_embeds.device
lxmert_output = self.lxmert(
input_ids=input_ids,
......@@ -1210,13 +1221,13 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
total_loss = (
None
if (masked_lm_labels is None and matched_label is None and obj_labels is None and ans is None)
if (labels is None and matched_label is None and obj_labels is None and ans is None)
else torch.tensor(0.0, device=device)
)
if masked_lm_labels is not None and self.task_mask_lm:
if labels is not None and self.task_mask_lm:
masked_lm_loss = self.loss_fcts["ce"](
lang_prediction_scores.view(-1, self.config.vocab_size),
masked_lm_labels.view(-1),
labels.view(-1),
)
total_loss += masked_lm_loss
if matched_label is not None and self.task_matched:
......@@ -1391,6 +1402,7 @@ class LxmertForQuestionAnswering(LxmertPreTrainedModel):
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
lxmert_output = self.lxmert(
input_ids=input_ids,
......
......@@ -1194,13 +1194,14 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
position_ids=None,
head_mask=None,
inputs_embeds=None,
next_sentence_label=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
r"""
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the next sentence prediction (classification) loss. Input should be a sequence pair
(see ``input_ids`` docstring). Indices should be in ``[0, 1]``.
......@@ -1221,10 +1222,18 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
>>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1]))
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
>>> loss = outputs.loss
>>> logits = outputs.logits
"""
if "next_sentence_label" in kwargs:
warnings.warn(
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.mobilebert(
......@@ -1243,9 +1252,9 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
seq_relationship_score = self.cls(pooled_output)
next_sentence_loss = None
if next_sentence_label is not None:
if labels is not None:
loss_fct = CrossEntropyLoss()
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), labels.view(-1))
if not return_dict:
output = (seq_relationship_score,) + outputs[2:]
......
......@@ -824,7 +824,7 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(pooled_logits.view(-1), labels.view(-1))
loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
......
......@@ -221,7 +221,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
f"Some weights of the PyTorch model were not used when "
f"initializing the TF 2.0 model {tf_model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model trained on another task "
f"or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPretraining model).\n"
f"or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n"
f"- This IS NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect "
f"to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model)."
)
......@@ -375,7 +375,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
f"Some weights of the TF 2.0 model were not used when "
f"initializing the PyTorch model {pt_model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model trained on another task "
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPretraining model).\n"
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPreTraining model).\n"
f"- This IS NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect "
f"to be exactly identical (e.g. initializing a BertForSequenceClassification model from a TFBertForSequenceClassification model)."
)
......
......@@ -730,7 +730,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when "
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
......
......@@ -1047,7 +1047,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
......
......@@ -256,6 +256,9 @@ MODEL_FOR_MASKED_LM_MAPPING = None
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = None
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = None
MODEL_FOR_PRETRAINING_MAPPING = None
......@@ -313,6 +316,15 @@ class AutoModelForMultipleChoice:
requires_pytorch(self)
class AutoModelForNextSentencePrediction:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class AutoModelForPreTraining:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
......
......@@ -24,7 +24,10 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
if is_torch_available():
import torch
from transformers import (
MODEL_FOR_PRETRAINING_MAPPING,
AlbertConfig,
AlbertForMaskedLM,
AlbertForMultipleChoice,
......@@ -227,6 +230,20 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)
inputs_dict["sentence_order_label"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
return inputs_dict
def setUp(self):
self.model_tester = AlbertModelTester(self)
self.config_tester = ConfigTester(self, config_class=AlbertConfig, hidden_size=37)
......
......@@ -25,7 +25,10 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, r
if is_torch_available():
import torch
from transformers import (
MODEL_FOR_PRETRAINING_MAPPING,
BertConfig,
BertForMaskedLM,
BertForMultipleChoice,
......@@ -268,7 +271,7 @@ class BertModelTester:
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
next_sentence_label=sequence_labels,
labels=sequence_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))
......@@ -377,6 +380,20 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
)
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)
inputs_dict["next_sentence_label"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
return inputs_dict
def setUp(self):
self.model_tester = BertModelTester(self)
self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
......
......@@ -35,10 +35,12 @@ if is_torch_available():
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
AdaptiveEmbedding,
BertConfig,
BertModel,
......@@ -88,7 +90,10 @@ class ModelTesterMixin:
inputs_dict["end_positions"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
elif model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
elif model_class in [
*MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(),
*MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(),
]:
inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
......@@ -204,6 +209,41 @@ class ModelTesterMixin:
expected_arg_names = ["input_ids"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_training(self):
if not self.model_tester.is_training:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in MODEL_MAPPING.values():
continue
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_training_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if not self.model_tester.is_training or not hasattr(config, "gradient_checkpointing"):
return
config.gradient_checkpointing = True
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in MODEL_MAPPING.values():
continue
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
......
......@@ -38,7 +38,7 @@ class DPRModelTester:
parent,
batch_size=13,
seq_length=7,
is_training=True,
is_training=False,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
......
......@@ -24,7 +24,10 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
if is_torch_available():
import torch
from transformers import (
MODEL_FOR_PRETRAINING_MAPPING,
ElectraConfig,
ElectraForMaskedLM,
ElectraForMultipleChoice,
......@@ -285,6 +288,17 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)
return inputs_dict
def setUp(self):
self.model_tester = ElectraModelTester(self)
self.config_tester = ConfigTester(self, config_class=ElectraConfig, hidden_size=37)
......