"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4c14669a78856dae9bf53e5708935ba397f5e689"
Unverified commit 8bf73126, authored by Jared T Nielsen, committed by GitHub

Add AlbertForPreTraining and TFAlbertForPreTraining models. (#4057)



* Add AlbertForPreTraining and TFAlbertForPreTraining models.

* PyTorch conversion

* TensorFlow conversion

* style
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent c99fe038
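
As a quick orientation before the diff: the new AlbertForPreTraining head returns masked-LM scores, sentence-order-prediction (SOP) scores, and, when both label tensors are passed, their summed loss. A minimal usage sketch (not part of the diff; the labels below are placeholders, since real pre-training derives them from the MLM masking and segment-swapping procedure):

# Illustrative sketch of the new PyTorch head; labels are placeholders, not real pre-training data.
import torch
from transformers import AlbertTokenizer, AlbertForPreTraining

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = AlbertForPreTraining.from_pretrained("albert-base-v2")

input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
masked_lm_labels = torch.full_like(input_ids, -100)   # -100 marks positions ignored by the MLM loss
masked_lm_labels[0, 3] = input_ids[0, 3]              # pretend this position had been masked out
sentence_order_label = torch.tensor([0])              # 0 = original segment order, 1 = swapped

loss, prediction_scores, sop_scores = model(
    input_ids,
    masked_lm_labels=masked_lm_labels,
    sentence_order_label=sentence_order_label,
)[:3]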
@@ -287,6 +287,7 @@ if is_torch_available():
    from .modeling_albert import (
        AlbertPreTrainedModel,
        AlbertModel,
+       AlbertForPreTraining,
        AlbertForMaskedLM,
        AlbertForSequenceClassification,
        AlbertForQuestionAnswering,
@@ -490,6 +491,7 @@ if is_tf_available():
        TFAlbertPreTrainedModel,
        TFAlbertMainLayer,
        TFAlbertModel,
+       TFAlbertForPreTraining,
        TFAlbertForMaskedLM,
        TFAlbertForSequenceClassification,
        TFAlbertForQuestionAnswering,
......
@@ -20,7 +20,7 @@ import logging
import torch
-from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert
+from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
logging.basicConfig(level=logging.INFO)
@@ -30,7 +30,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pyt
    # Initialise PyTorch model
    config = AlbertConfig.from_json_file(albert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
-   model = AlbertForMaskedLM(config)
+   model = AlbertForPreTraining(config)
    # Load weights from tf checkpoint
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)
......
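With the conversion script now instantiating the full pre-training architecture, the SOP classifier weights in the official TF checkpoints are carried over instead of being discarded. A rough sketch of the resulting flow (file paths are illustrative placeholders, not taken from the diff):

# Sketch of the conversion flow after this change; paths are placeholders.
import torch
from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert

config = AlbertConfig.from_json_file("albert_config.json")
model = AlbertForPreTraining(config)                         # previously AlbertForMaskedLM, which has no SOP head
load_tf_weights_in_albert(model, config, "model.ckpt-best")  # copies TF variables, including the SOP classifier
torch.save(model.state_dict(), "pytorch_model.bin")          # the script then saves the state dict to the dump path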
@@ -46,7 +46,7 @@ from transformers import (
    OpenAIGPTConfig,
    RobertaConfig,
    T5Config,
-   TFAlbertForMaskedLM,
+   TFAlbertForPreTraining,
    TFBertForPreTraining,
    TFBertForQuestionAnswering,
    TFBertForSequenceClassification,
@@ -109,7 +109,7 @@ if is_torch_available():
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        CTRLLMHeadModel,
        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
-       AlbertForMaskedLM,
+       AlbertForPreTraining,
        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        T5ForConditionalGeneration,
        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
@@ -148,7 +148,7 @@ else:
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        CTRLLMHeadModel,
        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
-       AlbertForMaskedLM,
+       AlbertForPreTraining,
        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        T5ForConditionalGeneration,
        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
@@ -318,8 +318,8 @@ MODEL_CLASSES = {
    ),
    "albert": (
        AlbertConfig,
-       TFAlbertForMaskedLM,
-       AlbertForMaskedLM,
+       TFAlbertForPreTraining,
+       AlbertForPreTraining,
        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
......
@@ -111,7 +111,8 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
        # No ALBERT model currently handles the next sentence prediction task
        if "seq_relationship" in name:
-           continue
+           name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
+           name = name.replace("weights", "weight")
        name = name.split("/")
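For context, the renaming above redirects the checkpoint's SOP classifier variables, which were previously skipped, onto the new sop_classifier module. A minimal sketch of the effect, assuming a checkpoint variable of the usual cls/seq_relationship/... form (the exact prefix is an assumption, not taken from the diff):

# Illustrative only: how the two replace() calls remap an assumed TF variable name.
name = "cls/seq_relationship/output_weights"  # assumed checkpoint variable name
name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
name = name.replace("weights", "weight")
print(name)  # cls/sop_classifier/classifier/weight, matching AlbertForPreTraining.sop_classifier.classifier.weight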
@@ -568,6 +569,115 @@ class AlbertModel(AlbertPreTrainedModel):
        return outputs
+@add_start_docstrings(
+    """Albert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
+    a `sentence order prediction (classification)` head. """,
+    ALBERT_START_DOCSTRING,
+)
+class AlbertForPreTraining(AlbertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.albert = AlbertModel(config)
+        self.predictions = AlbertMLMHead(config)
+        self.sop_classifier = AlbertSOPHead(config)
+        self.init_weights()
+        self.tie_weights()
+
+    def tie_weights(self):
+        self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)
+
+    def get_output_embeddings(self):
+        return self.predictions.decoder
+
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        masked_lm_labels=None,
+        sentence_order_label=None,
+    ):
+        r"""
+        masked_lm_labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
+            Labels for computing the masked language modeling loss.
+            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
+            Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
+            in ``[0, ..., config.vocab_size]``.
+        sentence_order_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
+            Labels for computing the sentence order prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring).
+            Indices should be in ``[0, 1]``.
+            ``0`` indicates original order (sequence A, then sequence B),
+            ``1`` indicates switched order (sequence B, then sequence A).
+
+    Returns:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
+        loss (`optional`, returned when both ``masked_lm_labels`` and ``sentence_order_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Total loss as the sum of the masked language modeling loss and the sentence order prediction (classification) loss.
+        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        sop_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
+            Prediction scores of the sentence order prediction (classification) head (scores for original vs. switched order before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        from transformers import AlbertTokenizer, AlbertForPreTraining
+        import torch
+
+        tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
+        model = AlbertForPreTraining.from_pretrained('albert-base-v2')
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids)
+        prediction_scores, sop_scores = outputs[:2]
+        """
+        outputs = self.albert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+        )
+        sequence_output, pooled_output = outputs[:2]
+
+        prediction_scores = self.predictions(sequence_output)
+        sop_scores = self.sop_classifier(pooled_output)
+
+        outputs = (prediction_scores, sop_scores,) + outputs[2:]  # add hidden states and attention if they are here
+
+        if masked_lm_labels is not None and sentence_order_label is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
+            sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
+            total_loss = masked_lm_loss + sentence_order_loss
+            outputs = (total_loss,) + outputs
+
+        return outputs  # (loss), prediction_scores, sop_scores, (hidden_states), (attentions)
+
class AlbertMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
@@ -592,6 +702,19 @@ class AlbertMLMHead(nn.Module):
        return prediction_scores
+class AlbertSOPHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dropout = nn.Dropout(config.classifier_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, pooled_output):
+        dropout_pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(dropout_pooled_output)
+        return logits
+
@add_start_docstrings(
    "Albert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING,
)
......
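Since sentence_order_label expects 0 for the original A-then-B order and 1 for the swapped order, building SOP examples amounts to encoding the same segment pair in both orders. A minimal sketch (not part of the commit; encode_plus is the standard tokenizer API of this transformers version, and the sentences are invented for illustration):

# Illustrative construction of sentence-order-prediction inputs for AlbertForPreTraining.
import torch
from transformers import AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
seg_a = "The dog chased the ball."   # example sentences, chosen for illustration
seg_b = "It caught it in mid-air."

original = tokenizer.encode_plus(seg_a, seg_b, return_tensors="pt")  # sentence_order_label = 0
swapped = tokenizer.encode_plus(seg_b, seg_a, return_tensors="pt")   # sentence_order_label = 1
sentence_order_label = torch.tensor([0])  # label for the `original` encoding above

The token_type_ids returned by encode_plus distinguish the two segments when the pair is fed to the model.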
@@ -43,6 +43,7 @@ from .configuration_utils import PretrainedConfig
from .modeling_albert import (
    ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    AlbertForMaskedLM,
+   AlbertForPreTraining,
    AlbertForQuestionAnswering,
    AlbertForSequenceClassification,
    AlbertForTokenClassification,
@@ -189,7 +190,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
    [
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
-       (AlbertConfig, AlbertForMaskedLM),
+       (AlbertConfig, AlbertForPreTraining),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (BartConfig, BartForConditionalGeneration),
......
@@ -475,7 +475,6 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states, mode="linear") + self.decoder_bias
-       hidden_states = hidden_states + self.bias
        return hidden_states
@@ -718,6 +717,73 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
        return outputs
+@add_start_docstrings(
+    """Albert Model with two heads on top for pre-training:
+    a `masked language modeling` head and a `sentence order prediction` (classification) head. """,
+    ALBERT_START_DOCSTRING,
+)
+class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+        self.albert = TFAlbertMainLayer(config, name="albert")
+        self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
+        self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier")
+
+    def get_output_embeddings(self):
+        return self.albert.embeddings
+
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    def call(self, inputs, **kwargs):
+        r"""
+    Return:
+        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
+        prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        sop_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
+            Prediction scores of the sentence order prediction (classification) head (scores for original vs. switched order before SoftMax).
+        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        import tensorflow as tf
+        from transformers import AlbertTokenizer, TFAlbertForPreTraining
+
+        tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
+        model = TFAlbertForPreTraining.from_pretrained('albert-base-v2')
+        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
+        outputs = model(input_ids)
+        prediction_scores, sop_scores = outputs[:2]
+        """
+        outputs = self.albert(inputs, **kwargs)
+        sequence_output, pooled_output = outputs[:2]
+        prediction_scores = self.predictions(sequence_output)
+        sop_scores = self.sop_classifier(pooled_output, training=kwargs.get("training", False))
+        outputs = (prediction_scores, sop_scores) + outputs[2:]
+
+        return outputs
+
+class TFAlbertSOPHead(tf.keras.layers.Layer):
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.dropout = tf.keras.layers.Dropout(config.classifier_dropout_prob)
+        self.classifier = tf.keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier",
+        )
+
+    def call(self, pooled_output, training: bool):
+        dropout_pooled_output = self.dropout(pooled_output, training=training)
+        logits = self.classifier(dropout_pooled_output)
+        return logits
+
@add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
......
@@ -36,6 +36,7 @@ from .configuration_utils import PretrainedConfig
from .modeling_tf_albert import (
    TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    TFAlbertForMaskedLM,
+   TFAlbertForPreTraining,
    TFAlbertForQuestionAnswering,
    TFAlbertForSequenceClassification,
    TFAlbertModel,
@@ -132,7 +133,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
    [
        (T5Config, TFT5ForConditionalGeneration),
        (DistilBertConfig, TFDistilBertForMaskedLM),
-       (AlbertConfig, TFAlbertForMaskedLM),
+       (AlbertConfig, TFAlbertForPreTraining),
        (RobertaConfig, TFRobertaForMaskedLM),
        (BertConfig, TFBertForPreTraining),
        (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
@@ -412,7 +413,7 @@ class TFAutoModelForPreTraining(object):
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: :class:`~transformers.TFT5ModelWithLMHead` (T5 model)
            - contains `distilbert`: :class:`~transformers.TFDistilBertForMaskedLM` (DistilBERT model)
-           - contains `albert`: :class:`~transformers.TFAlbertForMaskedLM` (ALBERT model)
+           - contains `albert`: :class:`~transformers.TFAlbertForPreTraining` (ALBERT model)
            - contains `roberta`: :class:`~transformers.TFRobertaForMaskedLM` (RoBERTa model)
            - contains `bert`: :class:`~transformers.TFBertForPreTraining` (Bert model)
            - contains `openai-gpt`: :class:`~transformers.TFOpenAIGPTLMHeadModel` (OpenAI GPT model)
......
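With the two mapping updates above, the pre-training auto classes now resolve ALBERT checkpoints to the new heads. A quick sketch of the expected behaviour (illustrative; AutoModelForPreTraining is assumed to be the PyTorch counterpart of the TFAutoModelForPreTraining class shown in the diff):

# Illustrative: the *ForPreTraining auto classes now map ALBERT configs to the new heads.
from transformers import AutoModelForPreTraining, TFAutoModelForPreTraining

pt_model = AutoModelForPreTraining.from_pretrained("albert-base-v2")    # -> AlbertForPreTraining
tf_model = TFAutoModelForPreTraining.from_pretrained("albert-base-v2")  # -> TFAlbertForPreTraining
print(type(pt_model).__name__, type(tf_model).__name__)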
@@ -27,6 +27,7 @@ if is_torch_available():
    from transformers import (
        AlbertConfig,
        AlbertModel,
+       AlbertForPreTraining,
        AlbertForMaskedLM,
        AlbertForSequenceClassification,
        AlbertForTokenClassification,
@@ -38,7 +39,7 @@ if is_torch_available():
@require_torch
class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
-   all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else ()
+   all_model_classes = (AlbertModel, AlbertForPreTraining, AlbertForMaskedLM) if is_torch_available() else ()
    class AlbertModelTester(object):
        def __init__(
@@ -151,6 +152,30 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
            )
            self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
+        def create_and_check_albert_for_pretraining(
+            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+        ):
+            model = AlbertForPreTraining(config=config)
+            model.to(torch_device)
+            model.eval()
+            loss, prediction_scores, sop_scores = model(
+                input_ids,
+                attention_mask=input_mask,
+                token_type_ids=token_type_ids,
+                masked_lm_labels=token_labels,
+                sentence_order_label=sequence_labels,
+            )
+            result = {
+                "loss": loss,
+                "prediction_scores": prediction_scores,
+                "sop_scores": sop_scores,
+            }
+            self.parent.assertListEqual(
+                list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
+            )
+            self.parent.assertListEqual(list(result["sop_scores"].size()), [self.batch_size, config.num_labels])
+            self.check_loss_output(result)
+
        def create_and_check_albert_for_masked_lm(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
@@ -252,6 +277,10 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_albert_model(*config_and_inputs)
+    def test_for_pretraining(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_albert_for_pretraining(*config_and_inputs)
+
    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_albert_for_masked_lm(*config_and_inputs)
......
@@ -26,6 +26,7 @@ from .utils import require_tf, slow
if is_tf_available():
    from transformers.modeling_tf_albert import (
        TFAlbertModel,
+       TFAlbertForPreTraining,
        TFAlbertForMaskedLM,
        TFAlbertForSequenceClassification,
        TFAlbertForQuestionAnswering,
@@ -37,7 +38,13 @@ if is_tf_available():
class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
    all_model_classes = (
-       (TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification, TFAlbertForQuestionAnswering)
+       (
+           TFAlbertModel,
+           TFAlbertForPreTraining,
+           TFAlbertForMaskedLM,
+           TFAlbertForSequenceClassification,
+           TFAlbertForQuestionAnswering,
+       )
        if is_tf_available()
        else ()
    )
@@ -153,6 +160,22 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
            )
            self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])
+        def create_and_check_albert_for_pretraining(
+            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+        ):
+            config.num_labels = self.num_labels
+            model = TFAlbertForPreTraining(config=config)
+            inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
+            prediction_scores, sop_scores = model(inputs)
+            result = {
+                "prediction_scores": prediction_scores.numpy(),
+                "sop_scores": sop_scores.numpy(),
+            }
+            self.parent.assertListEqual(
+                list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
+            )
+            self.parent.assertListEqual(list(result["sop_scores"].shape), [self.batch_size, self.num_labels])
+
        def create_and_check_albert_for_masked_lm(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
@@ -216,6 +239,10 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_albert_model(*config_and_inputs)
+    def test_for_pretraining(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_albert_for_pretraining(*config_and_inputs)
+
    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_albert_for_masked_lm(*config_and_inputs)
......