Unverified Commit bbf26c4e authored by Patrick von Platen, committed by GitHub

Support T5 Generation (#3228)



* fix conflicts

* update bart max length test

* correct spelling mistakes

* implemented model specific encode function

* fix merge conflicts

* better naming

* save intermediate state -> need to rethink structure a bit

* leave tf problem as it is for now

* current version

* add layers.pop

* remove ipdb

* make style

* clean return cut decoding

* remove ipdbs

* Fix restoring layers in the decoder that don't exist.

* push good intermediate solution for now

* fix conflicts

* always good to refuse to merge conflicts when rebasing

* fix small bug

* improve function calls

* remove unused file

* add correct scope behavior for t5_generate
Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
parent 656e1386
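
Taken together, these changes let T5 run through the shared generate() API. A minimal usage sketch, assuming the post-merge class names and that the t5-small config carries the special-token ids (pad/eos/bos) the generation asserts below expect; the task prefix is only illustrative:

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# T5 is text-to-text; "summarize:" is just one of its pretraining task prefixes.
input_ids = tokenizer.encode(
    "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
)

# T5 conventionally starts decoding from the pad token, so pass decoder_start_token_id
# explicitly instead of relying on config.bos_token_id.
summary_ids = model.generate(input_ids, max_length=40, decoder_start_token_id=model.config.pad_token_id)
print(tokenizer.decode(summary_ids[0].tolist()))
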
......@@ -255,7 +255,7 @@ if is_torch_available():
from .modeling_t5 import (
T5PreTrainedModel,
T5Model,
T5WithLMHeadModel,
T5ForConditionalGeneration,
load_tf_weights_in_t5,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
)
......@@ -444,7 +444,7 @@ if is_tf_available():
from .modeling_tf_t5 import (
TFT5PreTrainedModel,
TFT5Model,
TFT5WithLMHeadModel,
TFT5ForConditionalGeneration,
TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP,
)
......
......@@ -76,6 +76,8 @@ class T5Config(PretrainedConfig):
layer_norm_epsilon=1e-6,
initializer_factor=1.0,
is_encoder_decoder=True,
pad_token_id=0,
eos_token_ids=[1],
**kwargs
):
super().__init__(
......
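
The two new keyword arguments exist so that generate() can read the special-token ids straight off the config; a tiny sketch using only the values this hunk introduces:

from transformers import T5Config

# pad_token_id / eos_token_ids are now first-class config fields (these are the new defaults).
config = T5Config(pad_token_id=0, eos_token_ids=[1])
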
......@@ -57,7 +57,7 @@ from transformers import (
TFOpenAIGPTLMHeadModel,
TFRobertaForMaskedLM,
TFRobertaForSequenceClassification,
TFT5WithLMHeadModel,
TFT5ForConditionalGeneration,
TFTransfoXLLMHeadModel,
TFXLMRobertaForMaskedLM,
TFXLMWithLMHeadModel,
......@@ -108,7 +108,7 @@ if is_torch_available():
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5WithLMHeadModel,
T5ForConditionalGeneration,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
)
else:
......@@ -145,7 +145,7 @@ else:
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5WithLMHeadModel,
T5ForConditionalGeneration,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
) = (
None,
......@@ -316,8 +316,8 @@ MODEL_CLASSES = {
),
"t5": (
T5Config,
TFT5WithLMHeadModel,
T5WithLMHeadModel,
TFT5ForConditionalGeneration,
T5ForConditionalGeneration,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
......
......@@ -93,7 +93,7 @@ from .modeling_roberta import (
RobertaForTokenClassification,
RobertaModel,
)
from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5Model, T5WithLMHeadModel
from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5ForConditionalGeneration, T5Model
from .modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TransfoXLModel
from .modeling_xlm import (
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
......@@ -166,7 +166,7 @@ MODEL_MAPPING = OrderedDict(
MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
(T5Config, T5WithLMHeadModel),
(T5Config, T5ForConditionalGeneration),
(DistilBertConfig, DistilBertForMaskedLM),
(AlbertConfig, AlbertForMaskedLM),
(CamembertConfig, CamembertForMaskedLM),
......@@ -186,7 +186,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
[
(T5Config, T5WithLMHeadModel),
(T5Config, T5ForConditionalGeneration),
(DistilBertConfig, DistilBertForMaskedLM),
(AlbertConfig, AlbertForMaskedLM),
(CamembertConfig, CamembertForMaskedLM),
......
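
With both mappings updated, the auto classes resolve T5 checkpoints to the renamed class; a hedged sketch:

from transformers import AutoModelWithLMHead, T5ForConditionalGeneration

# AutoModelWithLMHead consults MODEL_WITH_LM_HEAD_MAPPING, so a T5 checkpoint now yields the new class.
model = AutoModelWithLMHead.from_pretrained("t5-small")
assert isinstance(model, T5ForConditionalGeneration)
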
......@@ -885,18 +885,17 @@ class BartForConditionalGeneration(PretrainedBartModel):
return outputs
def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask):
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(
attention_mask.shape, encoder_inputs.shape
)
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
# first step, decoder_cached_states are empty
if not past[1]:
encoder_outputs, decoder_cached_states = past, None
else:
encoder_outputs, decoder_cached_states = past
input_ids = encoder_inputs
return {
"input_ids": input_ids, # ignored after first pass
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
......@@ -929,6 +928,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
past = ((new_enc_out, new_enc_mask), reordered_past)
return past
def get_encoder(self):
return self.model.encoder
def get_output_embeddings(self):
return self.lm_head
......
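
get_encoder() is the hook generate() now uses to run the encoder exactly once and thread its outputs through past, so prepare_inputs_for_generation no longer needs the raw encoder inputs. A minimal sketch with a tiny randomly initialised model; the config sizes are arbitrary and the keyword names are assumed from this version of BartConfig:

import torch
from transformers import BartConfig, BartForConditionalGeneration

config = BartConfig(
    vocab_size=64, d_model=16,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = BartForConditionalGeneration(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 7))
attention_mask = torch.ones_like(input_ids)

# generate() fetches the encoder via the new hook and caches its outputs in `past`;
# only decoder_input_ids, past and attention_mask reach prepare_inputs_for_generation.
encoder = model.get_encoder()
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
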
......@@ -464,7 +464,7 @@ class T5PreTrainedModel(PreTrainedModel):
input_mask = torch.tensor(DUMMY_MASK)
dummy_inputs = {
"decoder_input_ids": input_ids,
"encoder_input_ids": input_ids,
"input_ids": input_ids,
"decoder_attention_mask": input_mask,
}
return dummy_inputs
......@@ -474,7 +474,7 @@ class T5PreTrainedModel(PreTrainedModel):
factor = self.config.initializer_factor # Used for testing weights initialization
if isinstance(module, T5LayerNorm):
module.weight.data.fill_(factor * 1.0)
elif isinstance(module, (T5Model, T5WithLMHeadModel)):
elif isinstance(module, (T5Model, T5ForConditionalGeneration)):
# Mesh TensorFlow embeddings initialization
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
......@@ -503,10 +503,12 @@ class T5PreTrainedModel(PreTrainedModel):
class T5Stack(T5PreTrainedModel):
def __init__(self, config):
def __init__(self, config, embed_tokens=None):
super().__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.embed_tokens = embed_tokens
self.is_decoder = config.is_decoder
self.block = nn.ModuleList(
......@@ -517,21 +519,46 @@ class T5Stack(T5PreTrainedModel):
self.init_weights()
def get_input_embeddings(self):
return self.embed_tokens
def get_output_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, new_embeddings):
self.embed_tokens = new_embeddings
def forward(
self,
hidden_states,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
):
batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
if attention_mask is None:
attention_mask = torch.ones(batch_size, seq_length).to(hidden_states.device)
attention_mask = torch.ones(batch_size, seq_length).to(inputs_embeds.device)
if self.is_decoder and encoder_attention_mask is None:
encoder_seq_length = encoder_hidden_states.shape[1]
encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(hidden_states.device)
encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(inputs_embeds.device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
......@@ -542,7 +569,7 @@ class T5Stack(T5PreTrainedModel):
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder:
seq_ids = torch.arange(seq_length, device=hidden_states.device)
seq_ids = torch.arange(seq_length, device=inputs_embeds.device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
causal_mask = causal_mask.to(attention_mask)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
......@@ -605,7 +632,7 @@ class T5Stack(T5PreTrainedModel):
position_bias = None
encoder_decoder_position_bias = None
hidden_states = self.dropout(hidden_states)
hidden_states = self.dropout(inputs_embeds)
for i, layer_module in enumerate(self.block):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......@@ -731,11 +758,11 @@ class T5Model(T5PreTrainedModel):
self.shared = nn.Embedding(config.vocab_size, config.d_model)
encoder_config = copy.deepcopy(config)
self.encoder = T5Stack(encoder_config)
self.encoder = T5Stack(encoder_config, self.shared)
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
self.decoder = T5Stack(decoder_config)
self.decoder = T5Stack(decoder_config, self.shared)
self.init_weights()
......@@ -744,6 +771,8 @@ class T5Model(T5PreTrainedModel):
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
......@@ -753,55 +782,41 @@ class T5Model(T5PreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(self, **kwargs):
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_common = dict(
(k, v) for k, v in kwargs.items() if not k.startswith("encoder_") and not k.startswith("decoder_")
)
kwargs_encoder = kwargs_common.copy()
kwargs_decoder = kwargs_common.copy()
kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_")))
kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_")))
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
head_mask=None,
):
# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
encoder_attention_mask = kwargs_encoder.get("attention_mask", None)
if encoder_hidden_states is None:
# Convert encoder inputs in embeddings if needed
hidden_states = kwargs_encoder.pop("inputs_embeds", None)
if hidden_states is None:
encoder_inputs_ids = kwargs_encoder.pop("input_ids")
hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings
if encoder_attention_mask is not None:
# Apply masking
encoder_attention_mask = (encoder_attention_mask != 0).to(hidden_states)
hidden_states = hidden_states * encoder_attention_mask.unsqueeze(-1)
encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
else:
encoder_outputs = ()
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask
)
# Decode
# Convert decoder inputs in embeddings if needed
hidden_states = kwargs_decoder.pop("inputs_embeds", None)
if hidden_states is None:
decoder_inputs_ids = kwargs_decoder.pop("input_ids")
hidden_states = self.shared(decoder_inputs_ids)
hidden_states = encoder_outputs[0]
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
kwargs_decoder["encoder_attention_mask"] = encoder_attention_mask
decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
)
return decoder_outputs + encoder_outputs
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
class T5WithLMHeadModel(T5PreTrainedModel):
class T5ForConditionalGeneration(T5PreTrainedModel):
r"""
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss.
......@@ -825,7 +840,7 @@ class T5WithLMHeadModel(T5PreTrainedModel):
Examples::
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5WithLMHeadModel.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids=input_ids, lm_labels=input_ids)
loss, prediction_scores = outputs[:2]
......@@ -839,11 +854,11 @@ class T5WithLMHeadModel(T5PreTrainedModel):
self.shared = nn.Embedding(config.vocab_size, config.d_model)
encoder_config = copy.deepcopy(config)
self.encoder = T5Stack(encoder_config)
self.encoder = T5Stack(encoder_config, self.shared)
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
self.decoder = T5Stack(decoder_config)
self.decoder = T5Stack(decoder_config, self.shared)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
......@@ -854,50 +869,46 @@ class T5WithLMHeadModel(T5PreTrainedModel):
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def get_output_embeddings(self):
return self.lm_head
def forward(self, **kwargs):
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as whole.
# We let the specific kwargs override the common ones in case of conflict.
lm_labels = kwargs.pop("decoder_lm_labels", None)
def get_encoder(self):
return self.encoder
kwargs_common = dict(
(k, v) for k, v in kwargs.items() if not k.startswith("encoder_") and not k.startswith("decoder_")
)
kwargs_encoder = kwargs_common.copy()
kwargs_decoder = kwargs_common.copy()
kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_")))
kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_")))
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
lm_labels=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
head_mask=None,
):
# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
if encoder_hidden_states is None:
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
hidden_states = kwargs_encoder.pop("inputs_embeds", None)
if hidden_states is None:
encoder_inputs_ids = kwargs_encoder.pop("input_ids")
hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings
encoder_outputs = self.encoder(
input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask
)
encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
else:
encoder_outputs = ()
hidden_states = encoder_outputs[0]
# Decode
# Convert decoder inputs in embeddings if needed
hidden_states = kwargs_decoder.pop("inputs_embeds", None)
if hidden_states is None:
decoder_inputs_ids = kwargs_decoder.pop("input_ids")
hidden_states = self.shared(decoder_inputs_ids)
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
)
sequence_output = decoder_outputs[0]
# Rescale output before projecting on vocab
......@@ -916,3 +927,22 @@ class T5WithLMHeadModel(T5PreTrainedModel):
) + decoder_outputs # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
return decoder_outputs + encoder_outputs
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
# first step
if type(past) is tuple:
encoder_outputs = past
else:
encoder_outputs = (past,)
return {
"decoder_input_ids": input_ids,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
}
def _reorder_cache(self, past, beam_idx):
# past does not have to be re-ordered for T5.
return past
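
The keyword-prefix dispatch (encoder_*/decoder_* kwargs) is gone: both stacks are addressed with explicit arguments, and prepare_inputs_for_generation simply forwards the cached encoder outputs. A hedged sketch with a tiny random config (sizes arbitrary):

import torch
from transformers import T5Config, T5ForConditionalGeneration

config = T5Config(
    vocab_size=64, d_model=16, d_kv=8, d_ff=32,
    num_layers=2, num_heads=2,
    pad_token_id=0, eos_token_ids=[1],
)
model = T5ForConditionalGeneration(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 7))
attention_mask = torch.ones_like(input_ids)
decoder_input_ids = torch.full((1, 1), config.pad_token_id, dtype=torch.long)

# New calling convention: plain `input_ids` / `decoder_input_ids` keywords.
lm_logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)[0]

# During generation the encoder runs once; its outputs come back in as `past`.
encoder_outputs = model.get_encoder()(input_ids=input_ids, attention_mask=attention_mask)
model_inputs = model.prepare_inputs_for_generation(decoder_input_ids, past=encoder_outputs, attention_mask=attention_mask)
next_token_logits = model(**model_inputs)[0][:, -1, :]
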
......@@ -66,7 +66,7 @@ from .modeling_tf_roberta import (
TFRobertaForTokenClassification,
TFRobertaModel,
)
from .modeling_tf_t5 import TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP, TFT5Model, TFT5WithLMHeadModel
from .modeling_tf_t5 import TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP, TFT5ForConditionalGeneration, TFT5Model
from .modeling_tf_transfo_xl import (
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
TFTransfoXLLMHeadModel,
......@@ -128,7 +128,7 @@ TF_MODEL_MAPPING = OrderedDict(
TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
(T5Config, TFT5WithLMHeadModel),
(T5Config, TFT5ForConditionalGeneration),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
......@@ -144,7 +144,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
[
(T5Config, TFT5WithLMHeadModel),
(T5Config, TFT5ForConditionalGeneration),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
......@@ -507,7 +507,7 @@ class TFAutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: TFT5WithLMHeadModel (T5 model)
- contains `t5`: TFT5ForConditionalGeneration (T5 model)
- contains `distilbert`: TFDistilBertForMaskedLM (DistilBERT model)
- contains `roberta`: TFRobertaForMaskedLM (RoBERTa model)
- contains `bert`: TFBertForMaskedLM (Bert model)
......@@ -571,7 +571,7 @@ class TFAutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: TFT5WithLMHeadModel (T5 model)
- contains `t5`: TFT5ForConditionalGeneration (T5 model)
- contains `distilbert`: TFDistilBertForMaskedLM (DistilBERT model)
- contains `roberta`: TFRobertaForMaskedLM (RoBERTa model)
- contains `bert`: TFBertForMaskedLM (Bert model)
......
......@@ -160,6 +160,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
if name not in pt_state_dict:
if allow_missing_keys:
continue
raise AttributeError("{} not found in PyTorch model".format(name))
array = pt_state_dict[name].numpy()
......@@ -288,6 +289,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
if allow_missing_keys:
missing_keys_pt.append(pt_weight_name)
continue
raise AttributeError("{} not found in TF 2.0 model".format(pt_weight_name))
array, transpose = tf_weights_map[pt_weight_name]
......
......@@ -355,16 +355,49 @@ class TFT5Block(tf.keras.layers.Layer):
return outputs # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
class _NoLayerEmbedTokens(object):
"""
This class wraps the TFSharedEmbeddings layer in a plain Python (non-Keras-layer) object to avoid
problems with weight restoring. It also makes sure that the layer is called from the correct
absolute scope so that the weights are saved and restored correctly.
"""
def __init__(self, layer, abs_scope_name=None):
self._layer = layer
self._abs_scope_name = abs_scope_name
def call(self, inputs, mode="embedding"):
if self._abs_scope_name is None:
return self._layer.call(inputs, mode)
# if an abs scope name is given to the embedding variable, call variable from absolute scope
with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
with tf.name_scope(abs_scope_name.original_name_scope):
return self._layer.call(inputs, mode)
def __call__(self, inputs, mode="embedding"):
if self._abs_scope_name is None:
return self._layer(inputs, mode)
# if an abs scope name is given to the embedding variable, call variable from absolute scope
with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
with tf.name_scope(abs_scope_name.original_name_scope):
return self._layer(inputs, mode)
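
A hedged sketch of what the wrapper buys: the shared TFSharedEmbeddings layer is created once under its own variable scope, and both stacks call it through the wrapper, so the weights are resolved from the correct absolute scope without being registered as sublayers of the encoder/decoder (which is what broke weight restoring):

import tensorflow as tf
from transformers.modeling_tf_t5 import _NoLayerEmbedTokens
from transformers.modeling_tf_utils import TFSharedEmbeddings

shared = TFSharedEmbeddings(64, 16, name="shared")  # (vocab_size, hidden_size); sizes arbitrary

# Capture the absolute scope of the shared weights, as TFT5Model.__init__ does below.
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
    pass

embed_tokens = _NoLayerEmbedTokens(shared, abs_scope_name=shared_abs_scope_name)

input_ids = tf.constant([[5, 7, 11]])
embeddings = embed_tokens(input_ids, mode="embedding")  # token lookup, shape (1, 3, 16)
logits = embed_tokens(embeddings, mode="linear")        # tied output projection, shape (1, 3, 64)
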
####################################################
# The full model without a specific pretrained or finetuning head is
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
####################################################
class TFT5MainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
def __init__(self, config, embed_tokens=None, **kwargs):
super().__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.embed_tokens = embed_tokens
self.is_decoder = config.is_decoder
self.config = config
self.num_hidden_layers = config.num_layers
......@@ -375,6 +408,15 @@ class TFT5MainLayer(tf.keras.layers.Layer):
self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def get_input_embeddings(self):
return self.embed_tokens
def get_output_embeddings(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError  # Not implemented yet in the library for TF 2.0 models
......@@ -383,15 +425,31 @@ class TFT5MainLayer(tf.keras.layers.Layer):
def call(
self,
hidden_states,
input_ids,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
training=False,
):
batch_size, seq_length = shape_list(hidden_states)[:2]
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
if attention_mask is None:
attention_mask = tf.fill((batch_size, seq_length), 1)
if self.is_decoder and encoder_attention_mask is None:
......@@ -465,6 +523,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
all_attentions = ()
position_bias = None
encoder_decoder_position_bias = None
hidden_states = inputs_embeds
for i, layer_module in enumerate(self.block):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......@@ -527,7 +587,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
input_mask = tf.constant(DUMMY_MASK)
dummy_inputs = {
"decoder_input_ids": input_ids,
"encoder_input_ids": input_ids,
"input_ids": input_ids,
"decoder_attention_mask": input_mask,
}
return dummy_inputs
......@@ -636,12 +696,18 @@ class TFT5Model(TFT5PreTrainedModel):
super().__init__(config, *inputs, **kwargs)
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
encoder_config = copy.deepcopy(config)
self.encoder = TFT5MainLayer(encoder_config, name="encoder")
self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
self.decoder = TFT5MainLayer(decoder_config, name="decoder")
self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
def get_input_embeddings(self):
return self.shared
......@@ -650,54 +716,45 @@ class TFT5Model(TFT5PreTrainedModel):
return self.shared
def call(self, decoder_input_ids, **kwargs):
# We allow two types of multi-inputs:
# - traditional keyword arguments in the call method
# - all the arguments provided as a dict in the first positional argument of call
# The last option is useful to use the tf.keras fit() method.
if isinstance(decoder_input_ids, dict):
kwargs.update(decoder_input_ids)
else:
kwargs["decoder_input_ids"] = decoder_input_ids
kwargs_common = dict(
(k, v) for k, v in kwargs.items() if not k.startswith("encoder_") and not k.startswith("decoder_")
)
kwargs_encoder = kwargs_common.copy()
kwargs_decoder = kwargs_common.copy()
kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_")))
kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_")))
# retrieve arguments
input_ids = kwargs.get("input_ids", None)
decoder_input_ids = kwargs.get("decoder_input_ids", None)
attention_mask = kwargs.get("attention_mask", None)
encoder_outputs = kwargs.get("encoder_outputs", None)
decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
inputs_embeds = kwargs.get("inputs_embeds", None)
decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
head_mask = kwargs.get("head_mask", None)
# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
if encoder_hidden_states is None:
# Convert encoder inputs in embeddings if needed
hidden_states = kwargs_encoder.pop("inputs_embeds", None)
if hidden_states is None:
encoder_inputs_ids = kwargs_encoder.pop("input_ids")
hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask
)
encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
else:
encoder_outputs = ()
hidden_states = encoder_outputs[0]
# Decode
# Convert decoder inputs in embeddings if needed
hidden_states = kwargs_decoder.pop("inputs_embeds", None)
if hidden_states is None:
decoder_inputs_ids = kwargs_decoder.pop("input_ids")
hidden_states = self.shared(decoder_inputs_ids)
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
)
return decoder_outputs + encoder_outputs
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
class TFT5WithLMHeadModel(TFT5PreTrainedModel):
class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**prediction_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
......@@ -713,10 +770,10 @@ class TFT5WithLMHeadModel(TFT5PreTrainedModel):
Examples::
import tensorflow as tf
from transformers import T5Tokenizer, TFT5WithLMHeadModel
from transformers import T5Tokenizer, TFT5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = TFT5WithLMHeadModel.from_pretrained('t5-small')
model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids=input_ids)
prediction_scores = outputs[0]
......@@ -729,12 +786,18 @@ class TFT5WithLMHeadModel(TFT5PreTrainedModel):
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
encoder_config = copy.deepcopy(config)
self.encoder = TFT5MainLayer(encoder_config, name="encoder")
self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
self.decoder = TFT5MainLayer(decoder_config, name="decoder")
self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
def get_input_embeddings(self):
return self.shared
......@@ -742,52 +805,67 @@ class TFT5WithLMHeadModel(TFT5PreTrainedModel):
def get_output_embeddings(self):
return self.shared
def get_encoder(self):
return self.encoder
def call(self, decoder_input_ids, **kwargs):
# We allow two types of multi-inputs:
# - traditional keyword arguments in the call method
# - all the arguments provided as a dict in the first positional argument of call
# The last option is useful to use the tf.keras fit() method.
if isinstance(decoder_input_ids, dict):
kwargs.update(decoder_input_ids)
else:
kwargs["decoder_input_ids"] = decoder_input_ids
kwargs_common = dict(
(k, v) for k, v in kwargs.items() if not k.startswith("encoder_") and not k.startswith("decoder_")
)
kwargs_encoder = kwargs_common.copy()
kwargs_decoder = kwargs_common.copy()
kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_")))
kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_")))
# retrieve arguments
input_ids = kwargs.get("input_ids", None)
decoder_input_ids = kwargs.get("decoder_input_ids", None)
attention_mask = kwargs.get("attention_mask", None)
encoder_outputs = kwargs.get("encoder_outputs", None)
decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
inputs_embeds = kwargs.get("inputs_embeds", None)
decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
head_mask = kwargs.get("head_mask", None)
# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
if encoder_hidden_states is None:
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
hidden_states = kwargs_encoder.pop("inputs_embeds", None)
if hidden_states is None:
encoder_inputs_ids = kwargs_encoder.pop("input_ids")
hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings
encoder_outputs = self.encoder(
input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask
)
encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
else:
encoder_outputs = ()
hidden_states = encoder_outputs[0]
# Decode
# Convert decoder inputs in embeddings if needed
hidden_states = kwargs_decoder.pop("inputs_embeds", None)
if hidden_states is None:
decoder_inputs_ids = kwargs_decoder.pop("input_ids")
hidden_states = self.shared(decoder_inputs_ids)
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
)
sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5)
lm_logits = self.shared(sequence_output, mode="linear")
embed_tokens = self.get_output_embeddings()
lm_logits = embed_tokens(sequence_output, mode="linear")
decoder_outputs = (lm_logits,) + decoder_outputs[1:]
return decoder_outputs + encoder_outputs
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
# first step
if type(past) is tuple:
encoder_outputs = past
else:
encoder_outputs = (past,)
return {
"inputs": input_ids,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
}
def _reorder_cache(self, past, beam_idx):
# past does not have to be re-ordered for T5.
return past
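
The TF class follows the same renaming and input keys; a hedged sketch mirroring the updated tests, with a tiny random config (sizes arbitrary):

import tensorflow as tf
from transformers import T5Config, TFT5ForConditionalGeneration

config = T5Config(
    vocab_size=64, d_model=16, d_kv=8, d_ff=32,
    num_layers=2, num_heads=2,
    pad_token_id=0, eos_token_ids=[1],
)
model = TFT5ForConditionalGeneration(config)

input_ids = tf.constant([[5, 7, 11, 2, 1, 0, 0]])
inputs = {
    "input_ids": input_ids,               # encoder inputs (previously "encoder_input_ids")
    "decoder_input_ids": input_ids,
    "decoder_attention_mask": tf.ones_like(input_ids),
}
prediction_scores = model(inputs)[0]      # shape (1, 7, 64)
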
......@@ -474,6 +474,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
no_repeat_ngram_size=None,
num_return_sequences=None,
attention_mask=None,
decoder_start_token_id=None,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
......@@ -586,7 +587,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if self.get_output_embeddings() is None:
raise AttributeError(
"You tried to generate sequences with a model that does not have a LM Head."
"Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5WithLMHeadModel`, `TFTransfoXLLMHeadModel`, `TFXLMWithLMHeadModel`)"
"Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5ForConditionalGeneration`, `TFTransfoXLLMHeadModel`)"
)
max_length = max_length if max_length is not None else self.config.max_length
......@@ -608,6 +609,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
if input_ids is not None:
batch_size = shape_list(input_ids)[0] # overriden by the input batch_size
......@@ -634,6 +636,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
assert (eos_token_ids is None) or (
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
assert (
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
......@@ -703,6 +708,25 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
attention_mask, (effective_batch_size * num_beams, input_ids_len)
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
# create empty decoder_input_ids
input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
cur_len = 1
else:
encoder_outputs = None
cur_len = shape_list(input_ids)[-1]
if num_beams > 1:
output = self._generate_beam_search(
input_ids,
......@@ -716,13 +740,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
)
else:
......@@ -737,10 +764,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
)
......@@ -758,10 +788,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
top_p,
repetition_penalty,
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
decoder_start_token_id,
batch_size,
vocab_size,
encoder_outputs,
attention_mask,
):
""" Generate sequences for each example without beam search (num_beams == 1).
......@@ -772,7 +805,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
unfinished_sents = tf.ones_like(input_ids[:, 0])
sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length
past = None
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
......@@ -859,6 +892,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if tf.math.reduce_max(unfinished_sents) == 0:
break
# extend attention_mask for new generated input if only decoder
if self.config.is_encoder_decoder is False:
attention_mask = tf.concat(
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
)
cur_len = cur_len + 1
# if there are different sentences lengths in the batch, some batches have to be padded
......@@ -896,13 +935,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
top_p,
repetition_penalty,
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
decoder_start_token_id,
batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
encoder_outputs,
attention_mask,
):
""" Generate sequences for each example with beam search.
......@@ -923,8 +965,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
# cache compute states
past = None
past = encoder_outputs
# done sentences
done = [False for _ in range(batch_size)]
......@@ -1088,9 +1131,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
# re-order internal states
if past:
if past is not None:
past = self._reorder_cache(past, beam_idx)
if self.config.is_encoder_decoder is False:
attention_mask = tf.concat(
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
)
# update current length
cur_len = cur_len + 1
......
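
With the encoder-decoder branch in place, TF generate() mirrors the PyTorch flow: get_encoder() runs once, the decoder is seeded with decoder_start_token_id, and encoder_outputs ride along as past. A hedged end-to-end sketch; the task prefix is illustrative, and it assumes the t5-small config provides the special-token ids the asserts above check:

import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tf.constant([tokenizer.encode("translate English to German: How are you?")])

# T5 starts decoding from the pad token, so pass decoder_start_token_id explicitly.
generated = model.generate(input_ids, max_length=20, decoder_start_token_id=model.config.pad_token_id)
print(tokenizer.decode(generated[0].numpy().tolist()))
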
......@@ -806,10 +806,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
# TODO: think about how to make this cleaner
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.bos_token_id
)
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
if input_ids is not None:
batch_size = input_ids.shape[0] # overriden by the input batch_size
......@@ -912,20 +909,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
attention_mask = attention_mask.contiguous().view(
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
# encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
encoder_inputs = input_ids
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
# create empty decoder_input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
decoder_start_token_id, # TODO: see whether this is the best result
decoder_start_token_id,
dtype=torch.long,
device=next(self.parameters()).device,
)
cur_len = 1
else:
encoder_inputs = None
encoder_outputs = None
cur_len = input_ids.shape[-1]
if num_beams > 1:
......@@ -944,12 +948,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_inputs=encoder_inputs,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
)
else:
......@@ -964,10 +969,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size,
encoder_inputs=encoder_inputs,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
)
......@@ -985,10 +992,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
top_p,
repetition_penalty,
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
decoder_start_token_id,
batch_size,
encoder_inputs,
encoder_outputs,
attention_mask,
):
""" Generate sequences for each example without beam search (num_beams == 1).
......@@ -998,11 +1007,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
unfinished_sents = input_ids.new(batch_size).fill_(1)
sent_lengths = input_ids.new(batch_size).fill_(max_length)
past = None
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
)
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
......@@ -1099,12 +1107,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
bos_token_id,
pad_token_id,
eos_token_ids,
decoder_start_token_id,
batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
encoder_inputs,
encoder_outputs,
attention_mask,
):
""" Generate sequences for each example with beam search.
......@@ -1125,15 +1134,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# cache compute states
past = None
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
# done sentences
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask
)
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
......@@ -1152,8 +1159,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
if self.config.is_encoder_decoder and do_sample is False:
# TODO: maybe give better naming
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
# TODO (PVP) still a bit hacky here - there might be a better solution
scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
......@@ -1278,7 +1285,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
# re-order internal states
if past:
if past is not None:
past = self._reorder_cache(past, beam_idx)
# extend attention_mask for new generated input if only decoder
......@@ -1345,8 +1352,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert (len(hypo) == max_length for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
if self.config.is_encoder_decoder:
return decoded[:, 1:]
return decoded
# force one of token_ids to be generated by setting prob of all other tokens to 0.
......
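
Decoder-only models keep their previous behaviour through the same code path: there is no get_encoder(), past starts as None, and the attention mask is extended by one column per generated token. A hedged sanity-check sketch with GPT-2:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = torch.tensor([tokenizer.encode("The generate method now handles")])

# is_encoder_decoder is False here, so input_ids themselves grow during decoding.
generated = model.generate(input_ids, max_length=20, do_sample=False)
print(tokenizer.decode(generated[0].tolist()))
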
......@@ -82,7 +82,7 @@ class ModelTester:
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
eos_token_ids=[2],
eos_token_ids=self.eos_token_ids,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
)
......@@ -234,12 +234,10 @@ class BartHeadTests(unittest.TestCase):
def test_lm_forward(self):
config, input_ids, batch_size = self._get_config_and_data(output_past=False)
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
lm_model = BartForConditionalGeneration(config)
lm_model.to(torch_device)
loss, logits, enc_features = lm_model(
input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
)
loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels, decoder_input_ids=input_ids)
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
self.assertIsInstance(loss.item(), float)
......@@ -292,7 +290,7 @@ class BartHeadTests(unittest.TestCase):
no_repeat_ngram_size=3,
max_length=max_length,
)
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1))
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
# TODO(SS): uneven length batches, empty inputs
def test_shift_tokens_right(self):
......
......@@ -147,7 +147,7 @@ class ModelTesterMixin:
4 # decoder_features_or_logits, decoder_attentions, encoder_features, encoder_attentions
)
decoder_attention_idx = 1
if "lm_labels" in inputs_dict or "decoder_lm_labels" in inputs_dict: # loss will come first
if "lm_labels" in inputs_dict: # loss will come first
correct_outlen += 1 # compute loss
decoder_attention_idx += 1
self.assertEqual(out_len, correct_outlen)
......@@ -601,9 +601,9 @@ class ModelTesterMixin:
input_ids = inputs_dict["input_ids"]
del inputs_dict["input_ids"]
else:
encoder_input_ids = inputs_dict["encoder_input_ids"]
encoder_input_ids = inputs_dict["input_ids"]
decoder_input_ids = inputs_dict.get("decoder_input_ids", encoder_input_ids)
del inputs_dict["encoder_input_ids"]
del inputs_dict["input_ids"]
inputs_dict.pop("decoder_input_ids", None)
for model_class in self.all_model_classes:
......@@ -615,7 +615,7 @@ class ModelTesterMixin:
if not self.is_encoder_decoder:
inputs_dict["inputs_embeds"] = wte(input_ids)
else:
inputs_dict["encoder_inputs_embeds"] = wte(encoder_input_ids)
inputs_dict["inputs_embeds"] = wte(encoder_input_ids)
inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids)
with torch.no_grad():
......@@ -624,9 +624,7 @@ class ModelTesterMixin:
def test_lm_head_model_random_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict.get(
"input_ids", None
) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
input_ids = inputs_dict.get("input_ids")
if self.is_encoder_decoder:
config.output_past = True # needed for Bart TODO: might have to update for other encoder-decoder models
......
......@@ -24,14 +24,15 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import T5Config, T5Model, T5WithLMHeadModel
from transformers import T5Config, T5Model, T5ForConditionalGeneration
from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP
@require_torch
class T5ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5WithLMHeadModel) if is_torch_available() else ()
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
......@@ -56,6 +57,8 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
relative_attention_num_buckets=8,
dropout_rate=0.1,
initializer_factor=0.002,
eos_token_ids=[1],
pad_token_id=0,
scope=None,
):
self.parent = parent
......@@ -75,20 +78,22 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
self.dropout_rate = dropout_rate
self.initializer_factor = initializer_factor
self.scope = scope
self.eos_token_ids = eos_token_ids
self.pad_token_id = pad_token_id
def prepare_config_and_inputs(self):
encoder_input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
encoder_attention_mask = None
attention_mask = None
decoder_attention_mask = None
if self.use_attention_mask:
encoder_attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
decoder_lm_labels = None
lm_labels = None
if self.use_labels:
decoder_lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
config = T5Config(
vocab_size=self.vocab_size,
......@@ -101,41 +106,36 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
relative_attention_num_buckets=self.relative_attention_num_buckets,
dropout_rate=self.dropout_rate,
initializer_factor=self.initializer_factor,
eos_token_ids=self.eos_token_ids,
bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id,
)
return (
config,
encoder_input_ids,
input_ids,
decoder_input_ids,
encoder_attention_mask,
attention_mask,
decoder_attention_mask,
decoder_lm_labels,
lm_labels,
)
def check_loss_output(self, result):
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_t5_model(
self,
config,
encoder_input_ids,
decoder_input_ids,
encoder_attention_mask,
decoder_attention_mask,
decoder_lm_labels,
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config)
model.to(torch_device)
model.eval()
decoder_output, encoder_output = model(
encoder_input_ids=encoder_input_ids,
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
encoder_attention_mask=encoder_attention_mask,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
decoder_output, encoder_output = model(
encoder_input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids
)
decoder_output, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
result = {
"encoder_output": encoder_output,
......@@ -149,22 +149,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
)
def create_and_check_t5_with_lm_head(
self,
config,
encoder_input_ids,
decoder_input_ids,
encoder_attention_mask,
decoder_attention_mask,
decoder_lm_labels,
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5WithLMHeadModel(config=config)
model = T5ForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
outputs = model(
encoder_input_ids=encoder_input_ids,
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_lm_labels=decoder_lm_labels,
lm_labels=lm_labels,
)
loss, prediction_scores, encoder_features = outputs
self.parent.assertEqual(len(outputs), 3)
......@@ -181,17 +175,18 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
encoder_input_ids,
input_ids,
decoder_input_ids,
encoder_attention_mask,
attention_mask,
decoder_attention_mask,
decoder_lm_labels,
lm_labels,
) = config_and_inputs
inputs_dict = {
"encoder_input_ids": encoder_input_ids,
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"encoder_attention_mask": encoder_attention_mask,
}
return config, inputs_dict
......
......@@ -148,10 +148,12 @@ class TFModelTesterMixin:
pt_model_class = getattr(transformers, pt_model_class_name)
config.output_hidden_states = True
tf_model = model_class(config)
pt_model = pt_model_class(config)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
......@@ -221,7 +223,7 @@ class TFModelTesterMixin:
if self.is_encoder_decoder:
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"encoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="encoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
else:
input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
......@@ -393,9 +395,9 @@ class TFModelTesterMixin:
input_ids = inputs_dict["input_ids"]
del inputs_dict["input_ids"]
else:
encoder_input_ids = inputs_dict["encoder_input_ids"]
encoder_input_ids = inputs_dict["input_ids"]
decoder_input_ids = inputs_dict["decoder_input_ids"]
del inputs_dict["encoder_input_ids"]
del inputs_dict["input_ids"]
del inputs_dict["decoder_input_ids"]
for model_class in self.all_model_classes:
......@@ -405,7 +407,7 @@ class TFModelTesterMixin:
if not self.is_encoder_decoder:
inputs_dict["inputs_embeds"] = self._get_embeds(wte, input_ids)
else:
inputs_dict["encoder_inputs_embeds"] = self._get_embeds(wte, encoder_input_ids)
inputs_dict["inputs_embeds"] = self._get_embeds(wte, encoder_input_ids)
inputs_dict["decoder_inputs_embeds"] = self._get_embeds(wte, decoder_input_ids)
model(inputs_dict)
......@@ -413,9 +415,10 @@ class TFModelTesterMixin:
def test_lm_head_model_random_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict.get(
"input_ids", None
) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
input_ids = inputs_dict["input_ids"]
if self.is_encoder_decoder:
config.output_past = True # needed for Bart TODO: might have to update for other encoder-decoder models
for model_class in self.all_generative_model_classes:
model = model_class(config)
......
......@@ -24,14 +24,15 @@ from .utils import CACHE_DIR, require_tf, slow
if is_tf_available():
from transformers.modeling_tf_t5 import TFT5Model, TFT5WithLMHeadModel
from transformers.modeling_tf_t5 import TFT5Model, TFT5ForConditionalGeneration
@require_tf
class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = True
all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if is_tf_available() else ()
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
class TFT5ModelTester(object):
def __init__(
......@@ -51,6 +52,8 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
relative_attention_num_buckets=8,
dropout_rate=0.1,
initializer_factor=0.002,
eos_token_ids=[1],
pad_token_id=0,
scope=None,
):
self.parent = parent
......@@ -68,6 +71,8 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
self.relative_attention_num_buckets = relative_attention_num_buckets
self.dropout_rate = dropout_rate
self.initializer_factor = initializer_factor
self.eos_token_ids = eos_token_ids
self.pad_token_id = pad_token_id
self.scope = scope
def prepare_config_and_inputs(self):
......@@ -92,6 +97,9 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
relative_attention_num_buckets=self.relative_attention_num_buckets,
dropout_rate=self.dropout_rate,
initializer_factor=self.initializer_factor,
eos_token_ids=self.eos_token_ids,
bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id,
)
return (config, input_ids, input_mask, token_labels)
......@@ -99,15 +107,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
model = TFT5Model(config=config)
inputs = {
"encoder_input_ids": input_ids,
"input_ids": input_ids,
"decoder_input_ids": input_ids,
"decoder_attention_mask": input_mask,
}
encoder_output, decoder_output = model(inputs)
encoder_output, decoder_output = model(
input_ids, decoder_attention_mask=input_mask, encoder_input_ids=input_ids
)
encoder_output, decoder_output = model(input_ids, decoder_attention_mask=input_mask, input_ids=input_ids)
result = {
"encoder_output": encoder_output.numpy(),
......@@ -121,13 +127,15 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
)
def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
model = TFT5WithLMHeadModel(config=config)
inputs = {
"encoder_input_ids": input_ids,
model = TFT5ForConditionalGeneration(config=config)
inputs_dict = {
"input_ids": input_ids,
"decoder_input_ids": input_ids,
"decoder_attention_mask": input_mask,
}
prediction_scores, decoder_output = model(inputs)
prediction_scores, decoder_output = model(inputs_dict)
result = {
"prediction_scores": prediction_scores.numpy(),
}
......@@ -139,7 +147,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, token_labels) = config_and_inputs
inputs_dict = {
"encoder_input_ids": input_ids,
"input_ids": input_ids,
"decoder_input_ids": input_ids,
"decoder_attention_mask": input_mask,
}
......