Commit 727a79b3 authored by thomwolf

added TF2 model and tests - updated templates

parent 8fda532c
@@ -26,6 +26,8 @@ import logging
 import math
 import os
 import sys
+import copy
+import itertools
 from io import open

 import numpy as np
...
@@ -25,6 +25,8 @@ import logging
 import math
 import os
 import sys
+import copy
+import itertools
 from io import open

 import torch
...
@@ -158,6 +158,9 @@ if is_tf_available():
                                  TFCTRLLMHeadModel,
                                  TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
+    from .modeling_tf_t5 import (TFT5PreTrainedModel, TFT5Model, TFT5WithLMHeadModel,
+                                 TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP)

     # TF 2.0 <=> PyTorch conversion utilities
     from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
                                             load_pytorch_checkpoint_in_tf2_model,
...
@@ -27,6 +27,7 @@ from .configuration_xlm import XLMConfig
 from .configuration_roberta import RobertaConfig
 from .configuration_distilbert import DistilBertConfig
 from .configuration_ctrl import CTRLConfig
+from .configuration_t5 import T5Config

 logger = logging.getLogger(__name__)
@@ -64,6 +65,7 @@ class AutoConfig(object):
         The configuration class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
+            - contains `t5`: T5Config (T5 model)
             - contains `distilbert`: DistilBertConfig (DistilBERT model)
             - contains `bert`: BertConfig (Bert model)
             - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
@@ -114,7 +116,9 @@ class AutoConfig(object):
             assert unused_kwargs == {'foo': False}

         """
-        if 'distilbert' in pretrained_model_name_or_path:
+        if 't5' in pretrained_model_name_or_path:
+            return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        elif 'distilbert' in pretrained_model_name_or_path:
             return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
         elif 'roberta' in pretrained_model_name_or_path:
             return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
...
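For illustration, with the routing above any model name containing `t5` now resolves to a ``T5Config`` before the other patterns are tried. A minimal sketch, assuming the `t5-small` configuration file is actually published at the archive URL registered in this commit::

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained('t5-small')
    print(type(config).__name__)  # expected: T5Config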
@@ -27,8 +27,7 @@ from .configuration_utils import PretrainedConfig
 logger = logging.getLogger(__name__)

 T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    't5-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-uncased-config.json",
-    't5-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-uncased-config.json",
+    't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
 }
...
@@ -41,8 +41,7 @@ logger = logging.getLogger(__name__)
 # for the pretrained weights provided with the models
 ####################################################
 T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    't5-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-uncased-pytorch_model.bin",
-    't5-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-uncased-pytorch_model.bin",
+    't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-pytorch_model.bin",
 }

 ####################################################
@@ -442,7 +441,7 @@ class T5PreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(factor*1.0)
-        elif isinstance(module, T5Model):
+        elif isinstance(module, (T5Model, T5WithLMHeadModel)):
             # Mesh TensorFlow embeddings initialization
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
             module.shared.weight.data.normal_(mean=0.0, std=factor*1.0)
@@ -502,11 +501,10 @@ class T5Stack(T5PreTrainedModel):
         # ourselves in which case we just need to make it broadcastable to all heads.
         if attention_mask.dim() == 3:
             extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
             # Provided a padding mask of dimensions [batch_size, seq_length]
             # - if the model is a decoder, apply a causal mask in addition to the padding mask
             # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if attention_mask.dim() == 2:
             if self.config.is_decoder:
                 seq_ids = torch.arange(seq_length, device=hidden_states.device)
                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
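For reference, the hunk above folds the 2D case into an ``elif``: a ``[batch_size, seq_length]`` padding mask is expanded to ``[batch_size, 1, seq_length, seq_length]`` (broadcastable over heads), with a causal mask applied when the stack is a decoder. A standalone sketch of that logic, mirroring the ``causal_mask`` line shown here (the final combination with the padding mask is assumed from the surrounding, unshown code)::

    import torch

    batch_size, seq_length = 2, 5
    attention_mask = torch.ones(batch_size, seq_length)   # 1 = real token, 0 = padding

    seq_ids = torch.arange(seq_length)
    causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
    extended_attention_mask = causal_mask[:, None, :, :].float() * attention_mask[:, None, None, :]
    print(extended_attention_mask.shape)   # torch.Size([2, 1, 5, 5])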
@@ -593,7 +591,7 @@ class T5Stack(T5PreTrainedModel):
 T5_START_DOCSTRING = r"""    The T5 model was proposed in
     `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`_
     by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
-    It's an encoder decoder pre-trained transformer.
+    It's an encoder decoder transformer pre-trained in a text-to-text denoising generative setting.

     This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
     refer to the PyTorch documentation for all matter related to general usage and behavior.
@@ -634,16 +632,13 @@ T5_INPUTS_DOCSTRING = r"""
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
-        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
-            Indices of positions of each input sequence tokens in the position embeddings.
-            Selected in the range ``[0, config.max_position_embeddings - 1]``.
         **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
 """

-@add_start_docstrings("The bare single stack (encoder or decoder) of a T5 Model transformer outputting raw hidden-states"
+@add_start_docstrings("The bare T5 Model transformer outputting raw hidden-states"
                       "without any specific head on top.",
                       T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
 class T5Model(T5PreTrainedModel):
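As a concrete illustration of the two optional masks documented above (the pad id of 0 and the layer/head counts are arbitrary choices for this example)::

    import torch

    input_ids = torch.tensor([[37, 423, 215, 98, 0, 0]])   # assume 0 is the padding id
    attention_mask = (input_ids != 0).long()                # 1 = attend to token, 0 = padding
    head_mask = torch.ones(6, 8)                            # shape (num_layers, num_heads), 1 = keep head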
@@ -661,8 +656,8 @@ class T5Model(T5PreTrainedModel):
     Examples::

-        tokenizer = T5Tokenizer.from_pretrained('t5-base-uncased')
-        model = T5Model.from_pretrained('t5-base-uncased')
+        tokenizer = T5Tokenizer.from_pretrained('t5-small')
+        model = T5Model.from_pretrained('t5-small')
         input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
         outputs = model(input_ids)
         last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
@@ -752,8 +747,8 @@ class T5WithLMHeadModel(T5PreTrainedModel):
     Examples::

-        tokenizer = T5Tokenizer.from_pretrained('t5-base-uncased')
-        model = T5WithLMHeadModel.from_pretrained('t5-base-uncased')
+        tokenizer = T5Tokenizer.from_pretrained('t5-small')
+        model = T5WithLMHeadModel.from_pretrained('t5-small')
         input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
         outputs = model(input_ids, lm_labels=input_ids)
         loss, prediction_scores = outputs[:2]
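Note that the rewritten ``forward()`` in the next hunk takes only keyword arguments and routes them by ``encoder_``/``decoder_`` prefix, so a call matching the new code would look roughly like this sketch (assuming the `t5-small` PyTorch weights and the top-level exports are available)::

    import torch
    from transformers import T5Tokenizer, T5WithLMHeadModel

    tokenizer = T5Tokenizer.from_pretrained('t5-small')
    model = T5WithLMHeadModel.from_pretrained('t5-small')
    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
    outputs = model(encoder_input_ids=input_ids,
                    decoder_input_ids=input_ids,
                    decoder_lm_labels=input_ids)
    loss, prediction_scores = outputs[:2]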
@@ -763,31 +758,73 @@ class T5WithLMHeadModel(T5PreTrainedModel):
         super(T5WithLMHeadModel, self).__init__(config)
         self.model_dim = config.d_model

-        self.transformer = T5Model(config)
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        self.encoder = T5Stack(encoder_config)
+
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        self.decoder = T5Stack(decoder_config)
+
         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

         self.init_weights()

+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+
     def get_output_embeddings(self):
         return self.lm_head

     def forward(self, **kwargs):
+        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
+        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
+        # that apply to the model as whole.
+        # We let the specific kwargs override the common ones in case of conflict.
+
         lm_labels = kwargs.pop('decoder_lm_labels', None)
-        outputs = self.transformer(**kwargs)

-        sequence_output = outputs[0]
+        kwargs_common = dict((k, v) for k, v in kwargs.items()
+                             if not k.startswith("encoder_") and not k.startswith("decoder_"))
+        kwargs_encoder = kwargs_common.copy()
+        kwargs_decoder = kwargs_common.copy()
+        kwargs_encoder.update(dict((k[len("encoder_"):], v) for k, v in kwargs.items() if k.startswith("encoder_")))
+        kwargs_decoder.update(dict((k[len("decoder_"):], v) for k, v in kwargs.items() if k.startswith("decoder_")))
+
+        # Encode if needed (training, first prediction pass)
+        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
+        if encoder_hidden_states is None:
+            encoder_inputs_ids = kwargs_encoder.pop("input_ids")
+            hidden_states = self.shared(encoder_inputs_ids)  # Convert inputs in embeddings
+            encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
+            encoder_hidden_states = encoder_outputs[0]
+        else:
+            encoder_outputs = ()
+
+        # Decode
+        decoder_inputs_ids = kwargs_decoder.pop("input_ids")
+        hidden_states = self.shared(decoder_inputs_ids)  # Convert inputs in embeddings
+        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
+        kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
+        decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)
+
+        sequence_output = decoder_outputs[0]
         # Rescale output before projecting on vocab
         # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
         sequence_output = sequence_output * (self.model_dim ** -0.5)
         lm_logits = self.lm_head(sequence_output)

-        outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
+        decoder_outputs = (lm_logits,) + decoder_outputs[1:]  # Add hidden states and attention if they are here
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = (loss,) + outputs  # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+            decoder_outputs = (loss,) + decoder_outputs  # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

-        return outputs  # (lm_loss), lm_logits, (hidden_states), (attentions)
+        return decoder_outputs + encoder_outputs
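The kwarg routing at the top of the new ``forward()`` can be summarised with a small self-contained sketch: arguments prefixed with ``encoder_`` or ``decoder_`` go to the corresponding stack, everything else is shared by both, and the prefixed value wins on conflict (names below are illustrative only)::

    def split_kwargs(**kwargs):
        common = {k: v for k, v in kwargs.items()
                  if not k.startswith("encoder_") and not k.startswith("decoder_")}
        encoder = dict(common, **{k[len("encoder_"):]: v for k, v in kwargs.items() if k.startswith("encoder_")})
        decoder = dict(common, **{k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")})
        return encoder, decoder

    enc, dec = split_kwargs(head_mask=None,
                            encoder_input_ids="ENC_IDS",
                            decoder_input_ids="DEC_IDS",
                            decoder_attention_mask="DEC_MASK")
    # enc == {'head_mask': None, 'input_ids': 'ENC_IDS'}
    # dec == {'head_mask': None, 'input_ids': 'DEC_IDS', 'attention_mask': 'DEC_MASK'}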
@@ -156,7 +156,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
                 e.args += (symbolic_weight.shape, array.shape)
                 raise e

-            logger.info("Initialize TF weight {}".format(symbolic_weight.name))
+            # logger.warning("Initialize TF weight {}".format(symbolic_weight.name))

             weight_value_tuples.append((symbolic_weight, array))
             all_pytorch_weights.discard(name)
@@ -269,7 +269,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
             e.args += (pt_weight.shape, array.shape)
             raise e

-        logger.info("Initialize PyTorch weight {}".format(pt_weight_name))
+        # logger.warning("Initialize PyTorch weight {}".format(pt_weight_name))

         new_pt_params_dict[pt_weight_name] = torch.from_numpy(array)
         loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = torch.from_numpy(array)
...
This diff is collapsed.
@@ -160,8 +160,7 @@ class PreTrainedModel(nn.Module):
         base_model.vocab_size = new_num_tokens

         # Tie weights again if needed
-        if hasattr(self, 'tie_weights'):
-            self.tie_weights()
+        self.tie_weights()

         return model_embeds
@@ -458,8 +457,7 @@ class PreTrainedModel(nn.Module):
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                 model.__class__.__name__, "\n\t".join(error_msgs)))
-        if hasattr(model, 'tie_weights'):
-            model.tie_weights()  # make sure word embedding weights are still tied
+        model.tie_weights()  # make sure word embedding weights are still tied if needed

         # Set model in evaluation mode to desactivate DropOut modules by default
         model.eval()
...
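The two hunks above drop the ``hasattr`` guards and call ``tie_weights()`` unconditionally. This is safe when the base class defines the hook as a no-op unless a model exposes output embeddings, roughly along these lines (an illustrative sketch, not the library's exact implementation)::

    import torch.nn as nn

    class TiedModelSketch(nn.Module):
        def __init__(self, vocab_size=10, d_model=4):
            super().__init__()
            self.shared = nn.Embedding(vocab_size, d_model)
            self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        def get_input_embeddings(self):
            return self.shared

        def get_output_embeddings(self):
            return self.lm_head   # models without an LM head would return None

        def tie_weights(self):
            output_embeddings = self.get_output_embeddings()
            if output_embeddings is not None:
                # share one Parameter between the input and output projections
                output_embeddings.weight = self.get_input_embeddings().weight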
@@ -69,6 +69,7 @@ class TFCommonTestCases:
         test_torchscript = True
         test_pruning = True
         test_resize_embeddings = True
+        is_encoder_decoder = False

         def test_initialization(self):
             pass
@@ -156,7 +157,11 @@ class TFCommonTestCases:
         def test_compile_tf_model(self):
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

-            input_ids = tf.keras.Input(batch_shape=(2, 2000), name='input_ids', dtype='int32')
+            if self.is_encoder_decoder:
+                input_ids = {'decoder_input_ids': tf.keras.Input(batch_shape=(2, 2000), name='decoder_input_ids', dtype='int32'),
+                             'encoder_input_ids': tf.keras.Input(batch_shape=(2, 2000), name='encoder_input_ids', dtype='int32')}
+            else:
+                input_ids = tf.keras.Input(batch_shape=(2, 2000), name='input_ids', dtype='int32')
             optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
             loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
             metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
@@ -189,7 +194,7 @@ class TFCommonTestCases:
             outputs_dict = model(inputs_dict)

             inputs_keywords = copy.deepcopy(inputs_dict)
-            input_ids = inputs_keywords.pop('input_ids')
+            input_ids = inputs_keywords.pop('input_ids', inputs_keywords.pop('decoder_input_ids'))
             outputs_keywords = model(input_ids, **inputs_keywords)

             output_dict = outputs_dict[0].numpy()
@@ -216,12 +221,24 @@ class TFCommonTestCases:
                      self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
                 out_len = len(outputs)

+                if self.is_encoder_decoder:
+                    self.assertEqual(out_len % 2, 0)
+                    decoder_attentions = outputs[(out_len // 2)-1]
+                    self.assertEqual(model.config.output_attentions, True)
+                    self.assertEqual(model.config.output_hidden_states, False)
+                    self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+                    self.assertListEqual(
+                        list(decoder_attentions[0].shape[-3:]),
+                        [self.model_tester.num_attention_heads,
+                         self.model_tester.seq_length,
+                         self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+
                 # Check attention is always last and order is fine
                 config.output_attentions = True
                 config.output_hidden_states = True
                 model = model_class(config)
                 outputs = model(inputs_dict)
-                self.assertEqual(out_len+1, len(outputs))
+                self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
                 self.assertEqual(model.config.output_attentions, True)
                 self.assertEqual(model.config.output_hidden_states, True)
...
@@ -26,7 +26,7 @@ from .configuration_common_test import ConfigTester
 from transformers import T5Config, is_tf_available

-if False:  # is_tf_available():
+if is_tf_available():
     import tensorflow as tf
     from transformers.modeling_tf_t5 import (TFT5Model, TFT5WithLMHeadModel, TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP)
 else:
@@ -35,7 +35,8 @@ else:
 class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):

-    all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if False else ()  # is_tf_available() else ()
+    is_encoder_decoder = True
+    all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if is_tf_available() else ()

     class TFT5ModelTester(object):
@@ -45,22 +46,16 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
                      seq_length=7,
                      is_training=True,
                      use_input_mask=True,
-                     use_token_type_ids=True,
                      use_labels=True,
                      vocab_size=99,
-                     n_positions=14,
                      hidden_size=32,
                      num_hidden_layers=5,
                      num_attention_heads=4,
-                     intermediate_size=37,
-                     hidden_act="gelu",
-                     hidden_dropout_prob=0.1,
-                     attention_probs_dropout_prob=0.1,
-                     max_position_embeddings=512,
-                     type_vocab_size=16,
-                     type_sequence_label_size=2,
-                     initializer_range=0.02,
-                     num_labels=3,
-                     num_choices=4,
+                     d_ff=37,
+                     relative_attention_num_buckets=8,
+                     dropout_rate=0.1,
+                     initializer_factor=0.002,
                      scope=None,
                      ):
             self.parent = parent
@@ -68,22 +63,16 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
             self.seq_length = seq_length
             self.is_training = is_training
             self.use_input_mask = use_input_mask
-            self.use_token_type_ids = use_token_type_ids
             self.use_labels = use_labels
             self.vocab_size = vocab_size
-            self.n_positions = n_positions
             self.hidden_size = hidden_size
             self.num_hidden_layers = num_hidden_layers
             self.num_attention_heads = num_attention_heads
-            self.intermediate_size = intermediate_size
-            self.hidden_act = hidden_act
-            self.hidden_dropout_prob = hidden_dropout_prob
-            self.attention_probs_dropout_prob = attention_probs_dropout_prob
-            self.max_position_embeddings = max_position_embeddings
-            self.type_vocab_size = type_vocab_size
-            self.type_sequence_label_size = type_sequence_label_size
-            self.initializer_range = initializer_range
-            self.num_labels = num_labels
-            self.num_choices = num_choices
+            self.d_ff = d_ff
+            self.relative_attention_num_buckets = relative_attention_num_buckets
+            self.dropout_rate = dropout_rate
+            self.initializer_factor = initializer_factor
             self.scope = scope

         def prepare_config_and_inputs(self):
@@ -93,61 +82,53 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
             if self.use_input_mask:
                 input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

-            token_type_ids = None
-            if self.use_token_type_ids:
-                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-            sequence_labels = None
             token_labels = None
-            choice_labels = None
             if self.use_labels:
-                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-                choice_labels = ids_tensor([self.batch_size], self.num_choices)
+                token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

             config = T5Config(
                 vocab_size_or_config_json_file=self.vocab_size,
-                hidden_size=self.hidden_size,
-                num_hidden_layers=self.num_hidden_layers,
-                num_attention_heads=self.num_attention_heads,
-                intermediate_size=self.intermediate_size,
-                hidden_act=self.hidden_act,
-                hidden_dropout_prob=self.hidden_dropout_prob,
-                attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-                max_position_embeddings=self.max_position_embeddings,
-                type_vocab_size=self.type_vocab_size,
-                initializer_range=self.initializer_range)
-
-            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-        def create_and_check_t5_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+                n_positions=self.n_positions,
+                d_model=self.hidden_size,
+                d_ff=self.d_ff,
+                d_kv=self.hidden_size // self.num_attention_heads,
+                num_layers=self.num_hidden_layers,
+                num_heads=self.num_attention_heads,
+                relative_attention_num_buckets=self.relative_attention_num_buckets,
+                dropout_rate=self.dropout_rate,
+                initializer_factor=self.initializer_factor)
+
+            return (config, input_ids, input_mask, token_labels)
+
+        def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
             model = TFT5Model(config=config)
-            inputs = {'input_ids': input_ids,
-                      'attention_mask': input_mask,
-                      'token_type_ids': token_type_ids}
-            sequence_output, pooled_output = model(inputs)
-
-            inputs = [input_ids, input_mask]
-            sequence_output, pooled_output = model(inputs)
-
-            sequence_output, pooled_output = model(input_ids)
+            inputs = {'encoder_input_ids': input_ids,
+                      'decoder_input_ids': input_ids,
+                      'decoder_attention_mask': input_mask}
+            encoder_output, decoder_output = model(inputs)
+
+            encoder_output, decoder_output = model(input_ids,
+                                                   decoder_attention_mask=input_mask,
+                                                   encoder_input_ids=input_ids)

             result = {
-                "sequence_output": sequence_output.numpy(),
-                "pooled_output": pooled_output.numpy(),
+                "encoder_output": encoder_output.numpy(),
+                "decoder_output": decoder_output.numpy(),
             }
             self.parent.assertListEqual(
-                list(result["sequence_output"].shape),
+                list(result["encoder_output"].shape),
+                [self.batch_size, self.seq_length, self.hidden_size])
+            self.parent.assertListEqual(
+                list(result["decoder_output"].shape),
                 [self.batch_size, self.seq_length, self.hidden_size])
-            self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])

-        def create_and_check_t5_with_lm_head(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
             model = TFT5WithLMHeadModel(config=config)
-            inputs = {'input_ids': input_ids,
-                      'attention_mask': input_mask,
-                      'token_type_ids': token_type_ids}
-            prediction_scores, = model(inputs)
+            inputs = {'encoder_input_ids': input_ids,
+                      'decoder_input_ids': input_ids,
+                      'decoder_attention_mask': input_mask}
+            prediction_scores, decoder_output = model(inputs)
             result = {
                 "prediction_scores": prediction_scores.numpy(),
             }
@@ -158,14 +139,15 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
-            (config, input_ids, token_type_ids, input_mask,
-             sequence_labels, token_labels, choice_labels) = config_and_inputs
-            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
+            (config, input_ids, input_mask, token_labels) = config_and_inputs
+            inputs_dict = {'encoder_input_ids': input_ids,
+                           'decoder_input_ids': input_ids,
+                           'decoder_attention_mask': input_mask}
             return config, inputs_dict

     def setUp(self):
         self.model_tester = TFT5ModelTest.TFT5ModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=T5Config, hidden_size=37)
+        self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)

     def test_config(self):
         self.config_tester.run_common_tests()
@@ -181,7 +163,7 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/transformers_test/"
-        for model_name in ['t5-base']:
+        for model_name in ['t5-small']:
             model = TFT5Model.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)
...