"test/vscode:/vscode.git/clone" did not exist on "0a8dc92678f29769b12b1165fc25566ce19f0d50"
Unverified Commit 8b240a06 authored by Yih-Dar, committed by GitHub

Add TFEncoderDecoderModel + Add cross-attention to some TF models (#13222)



* Add cross attentions to TFGPT2Model

* Add TFEncoderDecoderModel

* Add TFBaseModelOutputWithPoolingAndCrossAttentions

* Add cross attentions to TFBertModel

* Fix past or past_key_values argument issue

* Fix generation

* Fix save and load

* Add some checks and comments

* Clean the code that deals with past keys/values

* Add kwargs to processing_inputs

* Add serving_output to TFEncoderDecoderModel

* Some cleaning + fix use_cache value issue

* Fix tests + add bert2bert/bert2gpt2 tests

* Fix more tests

* Ignore crossattention.bias when loading GPT2 weights into TFGPT2

* Fix return_dict_in_generate in tf generation

* Fix is_token_logit_eos_token bug in tf generation

* Finalize the tests after fixing some bugs

* Fix another is_token_logit_eos_token bug in tf generation

* Add/Update docs

* Add TFBertEncoderDecoderModelTest

* Clean test script

* Add TFEncoderDecoderModel to the library

* Add cross attentions to TFRobertaModel

* Add TFRobertaEncoderDecoderModelTest

* make style

* Change the way of position_ids computation

* bug fix

* Fix copies in tf_albert

* Remove some copied from and apply some fix-copies

* Remove some copied

* Add cross attentions to some other TF models

* Remove encoder_hidden_states from TFLayoutLMModel.call for now

* Make style

* Fix TFRemBertForCausalLM

* Revert the change to longformer + Remove copies

* Revert the change to albert and convbert + Remove copies

* make quality

* make style

* Add TFRembertEncoderDecoderModelTest

* make quality and fix-copies

* test TFRobertaForCausalLM

* Fixes for failed tests

* Fixes for failed tests

* fix more tests

* Fixes for failed tests

* Fix Auto mapping order

* Fix TFRemBertEncoder return value

* fix tf_rembert

* Check copies are OK

* Fix "TFBaseModelOutputWithPastAndCrossAttentions is not defined" error

* Add TFEncoderDecoderModelSaveLoadTests

* fix tf weight loading

* check the change of use_cache

* Revert the change

* Add missing test_for_causal_lm for TFRobertaModelTest

* Try cleaning past

* fix _reorder_cache

* Revert some files to original versions

* Keep as many copies as possible

* Apply suggested changes - Use raise ValueError instead of assert

* Move import to top

* Fix wrong require_torch

* Replace more assert by raise ValueError

* Add test_pt_tf_model_equivalence (the test won't pass for now)

* add test for loading/saving

* finish

* finish

* Remove test_pt_tf_model_equivalence

* Update tf modeling template

* Remove pooling, added in the prev. commit, from MainLayer

* Update tf modeling test template

* Move inputs["use_cache"] = False to modeling_tf_utils.py

* Fix torch.Tensor in the comment

* fix use_cache

* Fix missing use_cache in ElectraConfig

* Add a note to from_pretrained

* Fix style

* Change test_encoder_decoder_save_load_from_encoder_decoder_from_pt

* Fix TFMLP (in TFGPT2) activation issue

* Fix None past_key_values value in serving_output

* Don't call get_encoderdecoder_model in TFEncoderDecoderModelTest.test_configuration_tie until we have a TF checkpoint on Hub

* Apply review suggestions - style for cross_attns in serving_output

* Apply review suggestions - change assert + docstrings

* break the error message to respect the char limit

* deprecate the argument past

* fix docstring style

* Update the encoder-decoder rst file

* fix Unknown interpreted text role "method"

* fix typo
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 26b6ef79
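
Before the diff itself, a minimal usage sketch of what this commit adds. The checkpoint names and call pattern below mirror the existing PyTorch `EncoderDecoderModel` API that this PR ports to TensorFlow; they are illustrative, not taken from the PR's test code.

```python
# Minimal sketch of the new TFEncoderDecoderModel (checkpoint names are illustrative).
from transformers import BertTokenizer, TFEncoderDecoderModel

# Pair a TF encoder with a TF decoder; the decoder gains cross-attention
# layers over the encoder's hidden states (add_cross_attention=True).
model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

input_ids = tokenizer("Hello, my dog is cute", return_tensors="tf").input_ids
# A plain forward pass; during generation the decoder additionally caches
# past key/values (use_cache=True) so earlier positions are not recomputed.
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
logits = outputs.logits
```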
@@ -729,7 +729,6 @@ class TFMPNetModel(TFMPNetPreTrainedModel):
         )
         return outputs

-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
     def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
         hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
         attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
...
@@ -673,7 +673,6 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
             attentions=transformer_outputs.attentions,
         )

-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
     def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
         hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
         attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
...
@@ -23,15 +23,16 @@ import tensorflow as tf

 from ...activations_tf import get_tf_activation
 from ...file_utils import (
+    DUMMY_INPUTS,
     MULTIPLE_CHOICE_DUMMY_INPUTS,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
 )
 from ...modeling_tf_outputs import (
-    TFBaseModelOutput,
-    TFBaseModelOutputWithPooling,
-    TFCausalLMOutput,
+    TFBaseModelOutputWithPastAndCrossAttentions,
+    TFBaseModelOutputWithPoolingAndCrossAttentions,
+    TFCausalLMOutputWithCrossAttentions,
     TFMaskedLMOutput,
     TFMultipleChoiceModelOutput,
     TFQuestionAnsweringModelOutput,
@@ -112,6 +113,7 @@ class TFRemBertEmbeddings(tf.keras.layers.Layer):
         position_ids: tf.Tensor = None,
         token_type_ids: tf.Tensor = None,
         inputs_embeds: tf.Tensor = None,
+        past_key_values_length=0,
         training: bool = False,
     ) -> tf.Tensor:
         """
@@ -131,7 +133,9 @@ class TFRemBertEmbeddings(tf.keras.layers.Layer):
             token_type_ids = tf.fill(dims=input_shape, value=0)

         if position_ids is None:
-            position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
+            position_ids = tf.expand_dims(
+                tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
+            )

         position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
         position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
@@ -170,6 +174,8 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer):
         )
         self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)

+        self.is_decoder = config.is_decoder
+
     def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
         # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
         tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
@@ -182,16 +188,49 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer):
         hidden_states: tf.Tensor,
         attention_mask: tf.Tensor,
         head_mask: tf.Tensor,
+        encoder_hidden_states: tf.Tensor,
+        encoder_attention_mask: tf.Tensor,
+        past_key_value: Tuple[tf.Tensor],
         output_attentions: bool,
         training: bool = False,
     ) -> Tuple[tf.Tensor]:
         batch_size = shape_list(hidden_states)[0]
         mixed_query_layer = self.query(inputs=hidden_states)
-        mixed_key_layer = self.key(inputs=hidden_states)
-        mixed_value_layer = self.value(inputs=hidden_states)
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_layer = past_key_value[0]
+            value_layer = past_key_value[1]
+            attention_mask = encoder_attention_mask
+        elif is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
+            value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
+            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
+            key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
+            value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
+            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)

         query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
-        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
-        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
+
+        if self.is_decoder:
+            # if cross-attention, save Tuple(tf.Tensor, tf.Tensor) of all cross-attention key/value_states.
+            # Further calls to the cross-attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder), save Tuple(tf.Tensor, tf.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention, `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)

         # Take the dot product between "query" and "key" to get the raw attention scores.
         # (batch size, num_heads, seq_len_q, seq_len_k)
@@ -221,6 +260,8 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer):
         attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+
         return outputs
@@ -259,6 +300,9 @@ class TFRemBertAttention(tf.keras.layers.Layer):
         input_tensor: tf.Tensor,
         attention_mask: tf.Tensor,
         head_mask: tf.Tensor,
+        encoder_hidden_states: tf.Tensor,
+        encoder_attention_mask: tf.Tensor,
+        past_key_value: Tuple[tf.Tensor],
         output_attentions: bool,
         training: bool = False,
     ) -> Tuple[tf.Tensor]:
@@ -266,13 +310,17 @@ class TFRemBertAttention(tf.keras.layers.Layer):
             hidden_states=input_tensor,
             attention_mask=attention_mask,
             head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_value=past_key_value,
             output_attentions=output_attentions,
             training=training,
         )
         attention_output = self.dense_output(
             hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
         )
-        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        # add attentions (possibly with past_key_value) if we output them
+        outputs = (attention_output,) + self_outputs[1:]

         return outputs
@@ -323,6 +371,12 @@ class TFRemBertLayer(tf.keras.layers.Layer):
         super().__init__(**kwargs)

         self.attention = TFRemBertAttention(config, name="attention")
+        self.is_decoder = config.is_decoder
+        self.add_cross_attention = config.add_cross_attention
+        if self.add_cross_attention:
+            if not self.is_decoder:
+                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+            self.crossattention = TFRemBertAttention(config, name="crossattention")
         self.intermediate = TFRemBertIntermediate(config, name="intermediate")
         self.bert_output = TFRemBertOutput(config, name="output")
@@ -331,22 +385,69 @@ class TFRemBertLayer(tf.keras.layers.Layer):
         hidden_states: tf.Tensor,
         attention_mask: tf.Tensor,
         head_mask: tf.Tensor,
+        encoder_hidden_states: Optional[tf.Tensor],
+        encoder_attention_mask: Optional[tf.Tensor],
+        past_key_value: Optional[Tuple[tf.Tensor]],
         output_attentions: bool,
         training: bool = False,
     ) -> Tuple[tf.Tensor]:
-        attention_outputs = self.attention(
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        self_attention_outputs = self.attention(
             input_tensor=hidden_states,
             attention_mask=attention_mask,
             head_mask=head_mask,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            past_key_value=self_attn_past_key_value,
             output_attentions=output_attentions,
             training=training,
         )
-        attention_output = attention_outputs[0]
+        attention_output = self_attention_outputs[0]
+
+        # if decoder, the last output is a tuple of self-attn cache
+        if self.is_decoder:
+            outputs = self_attention_outputs[1:-1]
+            present_key_value = self_attention_outputs[-1]
+        else:
+            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        cross_attn_present_key_value = None
+        if self.is_decoder and encoder_hidden_states is not None:
+            if not hasattr(self, "crossattention"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
+                    "by setting `config.add_cross_attention=True`"
+                )
+
+            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            cross_attention_outputs = self.crossattention(
+                input_tensor=attention_output,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
+                training=training,
+            )
+            attention_output = cross_attention_outputs[0]
+            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
+
+            # add cross-attn cache to positions 3,4 of present_key_value tuple
+            cross_attn_present_key_value = cross_attention_outputs[-1]
+            present_key_value = present_key_value + cross_attn_present_key_value
+
         intermediate_output = self.intermediate(hidden_states=attention_output)
         layer_output = self.bert_output(
             hidden_states=intermediate_output, input_tensor=attention_output, training=training
         )
-        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
+        outputs = (layer_output,) + outputs  # add attentions if we output them
+
+        # if decoder, return the attn key/values as the last output
+        if self.is_decoder:
+            outputs = outputs + (present_key_value,)

         return outputs
@@ -354,6 +455,7 @@ class TFRemBertLayer(tf.keras.layers.Layer):
 class TFRemBertEncoder(tf.keras.layers.Layer):
     def __init__(self, config: RemBertConfig, **kwargs):
         super().__init__(**kwargs)
+        self.config = config

         self.embedding_hidden_mapping_in = tf.keras.layers.Dense(
             units=config.hidden_size,
@@ -367,40 +469,62 @@ class TFRemBertEncoder(tf.keras.layers.Layer):
         hidden_states: tf.Tensor,
         attention_mask: tf.Tensor,
         head_mask: tf.Tensor,
+        encoder_hidden_states: tf.Tensor,
+        encoder_attention_mask: tf.Tensor,
+        past_key_values: Tuple[Tuple[tf.Tensor]],
+        use_cache: bool,
         output_attentions: bool,
         output_hidden_states: bool,
         return_dict: bool,
         training: bool = False,
-    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
+    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
         hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states)
         all_hidden_states = (hidden_states,) if output_hidden_states else None
         all_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+        next_decoder_cache = () if use_cache else None
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

+            past_key_value = past_key_values[i] if past_key_values is not None else None
+
             layer_outputs = layer_module(
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
                 head_mask=head_mask[i],
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 training=training,
             )
             hidden_states = layer_outputs[0]

+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1],)
+
             if output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention and encoder_hidden_states is not None:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

         # Add last layer
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

         if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+            return tuple(
+                v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
+            )

-        return TFBaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+        return TFBaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+            cross_attentions=all_cross_attentions,
         )
@@ -500,6 +624,7 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
         super().__init__(**kwargs)

         self.config = config
+        self.is_decoder = config.is_decoder

         self.embeddings = TFRemBertEmbeddings(config, name="embeddings")
         self.encoder = TFRemBertEncoder(config, name="encoder")
@@ -519,6 +644,7 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

+    # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call
     def call(
         self,
         input_ids: Optional[TFModelInputType] = None,
@@ -527,12 +653,16 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
         position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
         head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
         **kwargs,
-    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+    ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
         inputs = input_processing(
             func=self.call,
             config=self.config,
@@ -542,6 +672,10 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -549,6 +683,9 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
             kwargs_call=kwargs,
         )

+        if not self.config.is_decoder:
+            inputs["use_cache"] = False
+
         if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif inputs["input_ids"] is not None:
@@ -558,8 +695,16 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

+        batch_size, seq_length = input_shape
+
+        if inputs["past_key_values"] is None:
+            past_key_values_length = 0
+            inputs["past_key_values"] = [None] * len(self.encoder.layer)
+        else:
+            past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2]
+
         if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = tf.fill(dims=input_shape, value=1)
+            inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)

         if inputs["token_type_ids"] is None:
             inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
@@ -569,6 +714,7 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
             position_ids=inputs["position_ids"],
             token_type_ids=inputs["token_type_ids"],
             inputs_embeds=inputs["inputs_embeds"],
+            past_key_values_length=past_key_values_length,
             training=inputs["training"],
         )
@@ -577,7 +723,29 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
+        attention_mask_shape = shape_list(inputs["attention_mask"])
+
+        mask_seq_length = seq_length + past_key_values_length
+        # Copied from `modeling_tf_t5.py`
+        # Provided a padding mask of dimensions [batch_size, mask_seq_length]
+        # - if the model is a decoder, apply a causal mask in addition to the padding mask
+        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
+        if self.is_decoder:
+            seq_ids = tf.range(mask_seq_length)
+            causal_mask = tf.less_equal(
+                tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
+                seq_ids[None, :, None],
+            )
+            causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
+            extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :]
+            attention_mask_shape = shape_list(extended_attention_mask)
+            extended_attention_mask = tf.reshape(
+                extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
+            )
+        else:
+            extended_attention_mask = tf.reshape(
+                inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
+            )

         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
@@ -589,6 +757,29 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
         ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
         extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

+        # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
+        if self.is_decoder and inputs["encoder_attention_mask"] is not None:
+            # If a 2D or 3D attention mask is provided for the cross-attention,
+            # make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            inputs["encoder_attention_mask"] = tf.cast(
+                inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
+            )
+            num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
+            if num_dims_encoder_attention_mask == 3:
+                encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
+            if num_dims_encoder_attention_mask == 2:
+                encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
+
+            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
+            #                                                 tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
+
+            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+        else:
+            encoder_extended_attention_mask = None
+
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
@@ -603,6 +794,10 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
             hidden_states=embedding_output,
             attention_mask=extended_attention_mask,
             head_mask=inputs["head_mask"],
+            encoder_hidden_states=inputs["encoder_hidden_states"],
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=inputs["past_key_values"],
+            use_cache=inputs["use_cache"],
             output_attentions=inputs["output_attentions"],
             output_hidden_states=inputs["output_hidden_states"],
             return_dict=inputs["return_dict"],
@@ -613,13 +808,18 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
         pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None

         if not inputs["return_dict"]:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            return (
+                sequence_output,
+                pooled_output,
+            ) + encoder_outputs[1:]

-        return TFBaseModelOutputWithPooling(
+        return TFBaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
         )
@@ -632,6 +832,24 @@ class TFRemBertPreTrainedModel(TFPreTrainedModel):
     config_class = RemBertConfig
     base_model_prefix = "rembert"

+    @property
+    def dummy_inputs(self):
+        """
+        Dummy inputs to build the network.
+
+        Returns:
+            :obj:`Dict[str, tf.Tensor]`: The dummy inputs.
+        """
+        dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
+        # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
+        if self.config.add_cross_attention:
+            batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
+            shape = (batch_size, seq_len) + (self.config.hidden_size,)
+            h = tf.random.uniform(shape=shape)
+            dummy["encoder_hidden_states"] = h
+
+        return dummy
+

 REMBERT_START_DOCSTRING = r"""
@@ -740,7 +958,7 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="rembert",
-        output_type=TFBaseModelOutputWithPooling,
+        output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     def call(
@@ -751,12 +969,36 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
         position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
         head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: Optional[bool] = False,
         **kwargs,
-    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+    ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
+        r"""
+        encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation.
+        """
         inputs = input_processing(
             func=self.call,
             config=self.config,
@@ -766,6 +1008,10 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -779,6 +1025,10 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
             position_ids=inputs["position_ids"],
             head_mask=inputs["head_mask"],
             inputs_embeds=inputs["inputs_embeds"],
+            encoder_hidden_states=inputs["encoder_hidden_states"],
+            encoder_attention_mask=inputs["encoder_attention_mask"],
+            past_key_values=inputs["past_key_values"],
+            use_cache=inputs["use_cache"],
             output_attentions=inputs["output_attentions"],
             output_hidden_states=inputs["output_hidden_states"],
             return_dict=inputs["return_dict"],
@@ -787,15 +1037,25 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
         return outputs

-    def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
+    # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
+    def serving_output(
+        self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
+    ) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
+        output_cache = self.config.use_cache and self.config.is_decoder
+        pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
         hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
         attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+        cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
+        if not (self.config.output_attentions and self.config.add_cross_attention):
+            cross_attns = None

-        return TFBaseModelOutputWithPooling(
+        return TFBaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=output.last_hidden_state,
             pooler_output=output.pooler_output,
+            past_key_values=pkv,
             hidden_states=hs,
             attentions=attns,
+            cross_attentions=cross_attns,
         )
@@ -912,10 +1172,23 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
     def get_lm_head(self) -> tf.keras.layers.Layer:
         return self.mlm.predictions

+    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
+    def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
+        # cut decoder_input_ids if past is used
+        if past:
+            inputs = tf.expand_dims(inputs[:, -1], -1)
+
+        return {
+            "input_ids": inputs,
+            "attention_mask": attention_mask,
+            "past_key_values": past,
+            "use_cache": model_kwargs["use_cache"],
+        }
+
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="rembert",
-        output_type=TFCausalLMOutput,
+        output_type=TFCausalLMOutputWithCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     def call(
@@ -926,14 +1199,36 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
         position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
         head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
         training: Optional[bool] = False,
         **kwargs,
-    ) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]:
+    ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
         r"""
+        encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation.
         labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
             Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
             config.vocab_size - 1]``.
@@ -947,6 +1242,10 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -961,6 +1260,10 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
             position_ids=inputs["position_ids"],
             head_mask=inputs["head_mask"],
             inputs_embeds=inputs["inputs_embeds"],
+            encoder_hidden_states=inputs["encoder_hidden_states"],
+            encoder_attention_mask=inputs["encoder_attention_mask"],
+            past_key_values=inputs["past_key_values"],
+            use_cache=inputs["use_cache"],
             output_attentions=inputs["output_attentions"],
             output_hidden_states=inputs["output_hidden_states"],
             return_dict=inputs["return_dict"],
@@ -980,18 +1283,28 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

-        return TFCausalLMOutput(
+        return TFCausalLMOutputWithCrossAttentions(
             loss=loss,
             logits=logits,
+            past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
         )

-    def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
+    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
+    def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions:
+        output_cache = self.config.use_cache and self.config.is_decoder
+        pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
         hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
         attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+        cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
+        if not (self.config.output_attentions and self.config.add_cross_attention):
+            cross_attns = None

-        return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
+        return TFCausalLMOutputWithCrossAttentions(
+            logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
+        )


 @add_start_docstrings(
...
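An aside on the decoder mask logic added to `TFRemBertMainLayer.call` above: the `tf.less_equal` construction is compact, so here is a standalone sketch of the same trick with made-up shapes (and without the `past_key_values_length` offset the real code applies), showing how the causal mask combines with the padding mask before being turned into an additive bias.

```python
import tensorflow as tf

# Standalone illustration of the causal-mask construction used above (toy shapes).
batch_size, seq_length = 2, 5
attention_mask = tf.constant([[1, 1, 1, 1, 0], [1, 1, 1, 0, 0]])  # 1 = attend, 0 = padding

seq_ids = tf.range(seq_length)
# causal_mask[b, i, j] is True iff key position j <= query position i
causal_mask = tf.less_equal(
    tf.tile(seq_ids[None, None, :], (batch_size, seq_length, 1)),
    seq_ids[None, :, None],
)
causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)

# Combine with the padding mask, then broadcast to [batch_size, 1, seq, seq]
extended_attention_mask = causal_mask * attention_mask[:, None, :]
extended_attention_mask = extended_attention_mask[:, None, :, :]

# Additive form: 0.0 where attention is allowed, -10000.0 where it is masked
additive_mask = (1.0 - tf.cast(extended_attention_mask, tf.float32)) * -10000.0
```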
@@ -45,6 +45,7 @@ if is_torch_available():
 if is_tf_available():
     _import_structure["modeling_tf_roberta"] = [
         "TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "TFRobertaForCausalLM",
         "TFRobertaForMaskedLM",
         "TFRobertaForMultipleChoice",
         "TFRobertaForQuestionAnswering",
@@ -90,6 +91,7 @@ if TYPE_CHECKING:
     if is_tf_available():
         from .modeling_tf_roberta import (
             TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
+            TFRobertaForCausalLM,
             TFRobertaForMaskedLM,
             TFRobertaForMultipleChoice,
             TFRobertaForQuestionAnswering,
...
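The newly exported `TFRobertaForCausalLM` follows the usual causal-LM pattern: enable `is_decoder` on the config before loading. A minimal sketch, assuming the stock `roberta-base` checkpoint:

```python
from transformers import RobertaConfig, RobertaTokenizer, TFRobertaForCausalLM

# is_decoder=True activates the causal mask and the past_key_values cache path
config = RobertaConfig.from_pretrained("roberta-base", is_decoder=True)
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = TFRobertaForCausalLM.from_pretrained("roberta-base", config=config)

inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
outputs = model(inputs)  # TFCausalLMOutputWithCrossAttentions
prediction_logits = outputs.logits
```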
@@ -24,14 +24,16 @@ import tensorflow as tf

 from ...activations_tf import get_tf_activation
 from ...file_utils import (
+    DUMMY_INPUTS,
     MULTIPLE_CHOICE_DUMMY_INPUTS,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
 )
 from ...modeling_tf_outputs import (
-    TFBaseModelOutput,
-    TFBaseModelOutputWithPooling,
+    TFBaseModelOutputWithPastAndCrossAttentions,
+    TFBaseModelOutputWithPoolingAndCrossAttentions,
+    TFCausalLMOutputWithCrossAttentions,
     TFMaskedLMOutput,
     TFMultipleChoiceModelOutput,
     TFQuestionAnsweringModelOutput,
@@ -39,6 +41,7 @@ from ...modeling_tf_outputs import (
     TFTokenClassifierOutput,
 )
 from ...modeling_tf_utils import (
+    TFCausalLanguageModelingLoss,
     TFMaskedLanguageModelingLoss,
     TFModelInputType,
     TFMultipleChoiceLoss,
@@ -112,7 +115,7 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
         super().build(input_shape)

-    def create_position_ids_from_input_ids(self, input_ids):
+    def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
         """
         Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
         symbols are ignored. This is modified from fairseq's `utils.make_positions`.
@@ -122,11 +125,19 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
         Returns: tf.Tensor
         """
         mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
-        incremental_indices = tf.math.cumsum(mask, axis=1) * mask
+        incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask

         return incremental_indices + self.padding_idx

-    def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
+    def call(
+        self,
+        input_ids=None,
+        position_ids=None,
+        token_type_ids=None,
+        inputs_embeds=None,
+        past_key_values_length=0,
+        training=False,
+    ):
         """
         Applies embedding based on inputs tensor.
@@ -146,7 +157,9 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
         if position_ids is None:
             if input_ids is not None:
                 # Create the position ids from the input token ids. Any padded tokens remain padded.
-                position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids)
+                position_ids = self.create_position_ids_from_input_ids(
+                    input_ids=input_ids, past_key_values_length=past_key_values_length
+                )
             else:
                 position_ids = tf.expand_dims(
                     tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
@@ -210,6 +223,8 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
         )
         self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)

+        self.is_decoder = config.is_decoder
+
     def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
         # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
         tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
...@@ -222,16 +237,49 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer): ...@@ -222,16 +237,49 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
hidden_states: tf.Tensor, hidden_states: tf.Tensor,
attention_mask: tf.Tensor, attention_mask: tf.Tensor,
head_mask: tf.Tensor, head_mask: tf.Tensor,
encoder_hidden_states: tf.Tensor,
encoder_attention_mask: tf.Tensor,
past_key_value: Tuple[tf.Tensor],
output_attentions: bool, output_attentions: bool,
training: bool = False, training: bool = False,
) -> Tuple[tf.Tensor]: ) -> Tuple[tf.Tensor]:
batch_size = shape_list(hidden_states)[0] batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(inputs=hidden_states) mixed_query_layer = self.query(inputs=hidden_states)
mixed_key_layer = self.key(inputs=hidden_states)
mixed_value_layer = self.value(inputs=hidden_states) # If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) if self.is_decoder:
# if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores. # Take the dot product between "query" and "key" to get the raw attention scores.
# (batch size, num_heads, seq_len_q, seq_len_k) # (batch size, num_heads, seq_len_q, seq_len_k)
...@@ -261,6 +309,8 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer): ...@@ -261,6 +309,8 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs return outputs
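The caching contract above is easiest to see with concrete shapes: in the third branch, cached keys/values are concatenated with the freshly projected ones along the sequence axis (axis 2), so the cache grows by one position per decoding step. A minimal standalone sketch, with hypothetical sizes:

import tensorflow as tf

batch, heads, head_dim = 2, 12, 64                      # hypothetical dimensions
past_k = tf.random.normal((batch, heads, 5, head_dim))  # 5 cached positions
new_k = tf.random.normal((batch, heads, 1, head_dim))   # 1 freshly projected position

# Same operation as the `elif past_key_value is not None` branch above:
merged_k = tf.concat([past_k, new_k], axis=2)
print(merged_k.shape)  # (2, 12, 6, 64): the cache grew by one step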
@@ -299,6 +349,9 @@ class TFRobertaAttention(tf.keras.layers.Layer):
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
@@ -306,13 +359,17 @@ class TFRobertaAttention(tf.keras.layers.Layer):
            hidden_states=input_tensor,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
        )
        # add attentions (possibly with past_key_value) if we output them
        outputs = (attention_output,) + self_outputs[1:]

        return outputs
@@ -363,6 +420,12 @@ class TFRobertaLayer(tf.keras.layers.Layer):
        super().__init__(**kwargs)

        self.attention = TFRobertaAttention(config, name="attention")
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = TFRobertaAttention(config, name="crossattention")
        self.intermediate = TFRobertaIntermediate(config, name="intermediate")
        self.bert_output = TFRobertaOutput(config, name="output")
@@ -371,22 +434,69 @@ class TFRobertaLayer(tf.keras.layers.Layer):
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: Optional[tf.Tensor],
        encoder_attention_mask: Optional[tf.Tensor],
        past_key_value: Optional[Tuple[tf.Tensor]],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_value=self_attn_past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is a tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
                    "by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                input_tensor=attention_output,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        intermediate_output = self.intermediate(hidden_states=attention_output)
        layer_output = self.bert_output(
            hidden_states=intermediate_output, input_tensor=attention_output, training=training
        )
        outputs = (layer_output,) + outputs  # add attentions if we output them

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs
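One detail worth spelling out: a decoder layer's `past_key_value` is a flat 4-tuple, with the self-attention cache in slots 0-1 and the cross-attention cache in slots 2-3, which is why the code slices `[:2]` and `[-2:]`. A toy illustration of that layout (tensor values are placeholders):

import tensorflow as tf

shape = (1, 12, 4, 64)  # (batch, heads, seq, head_dim), illustrative only
self_k, self_v, cross_k, cross_v = (tf.zeros(shape) for _ in range(4))

past_key_value = (self_k, self_v, cross_k, cross_v)
self_attn_past_key_value = past_key_value[:2]    # slots 0,1: self-attention cache
cross_attn_past_key_value = past_key_value[-2:]  # slots 2,3: cross-attention cache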
@@ -395,7 +505,7 @@ class TFRobertaLayer(tf.keras.layers.Layer):
class TFRobertaEncoder(tf.keras.layers.Layer):
    def __init__(self, config: RobertaConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        self.layer = [TFRobertaLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    def call(
@@ -403,39 +513,61 @@ class TFRobertaEncoder(tf.keras.layers.Layer):
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: Optional[tf.Tensor],
        encoder_attention_mask: Optional[tf.Tensor],
        past_key_values: Optional[Tuple[Tuple[tf.Tensor]]],
        use_cache: Optional[bool],
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            past_key_value = past_key_values[i] if past_key_values is not None else None

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention and encoder_hidden_states is not None:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
            )

        return TFBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
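With the cache threaded through the encoder loop, a decoder built on these layers supports incremental decoding: run the prompt once, then feed a single token per step together with `past_key_values`. A hedged sketch using the causal-LM head added further down in this diff (assumes a transformers build that contains this PR; the checkpoint name is only an example):

import tensorflow as tf
from transformers import RobertaTokenizer, TFRobertaForCausalLM

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = TFRobertaForCausalLM.from_pretrained("roberta-base", is_decoder=True)

input_ids = tokenizer("Hello", return_tensors="tf").input_ids
outputs = model(input_ids=input_ids, use_cache=True)

# Each later step feeds only the newest token; history lives in the cache.
next_token = tf.argmax(outputs.logits[:, -1:, :], axis=-1)
outputs = model(input_ids=next_token, past_key_values=outputs.past_key_values, use_cache=True)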
@@ -447,6 +579,8 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
        super().__init__(**kwargs)

        self.config = config
        self.is_decoder = config.is_decoder

        self.num_hidden_layers = config.num_hidden_layers
        self.initializer_range = config.initializer_range
        self.output_attentions = config.output_attentions
@@ -483,12 +617,16 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
        **kwargs,
    ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
        inputs = input_processing(
            func=self.call,
            config=self.config,
@@ -498,6 +636,10 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
@@ -505,6 +647,9 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
            kwargs_call=kwargs,
        )

        if not self.config.is_decoder:
            inputs["use_cache"] = False

        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif inputs["input_ids"] is not None:
@@ -514,8 +659,16 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape

        if inputs["past_key_values"] is None:
            past_key_values_length = 0
            inputs["past_key_values"] = [None] * len(self.encoder.layer)
        else:
            past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2]

        if inputs["attention_mask"] is None:
            inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)

        if inputs["token_type_ids"] is None:
            inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
@@ -525,6 +678,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
            position_ids=inputs["position_ids"],
            token_type_ids=inputs["token_type_ids"],
            inputs_embeds=inputs["inputs_embeds"],
            past_key_values_length=past_key_values_length,
            training=inputs["training"],
        )
@@ -533,7 +687,29 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        attention_mask_shape = shape_list(inputs["attention_mask"])

        mask_seq_length = seq_length + past_key_values_length
        # Copied from `modeling_tf_t5.py`
        # Provided a padding mask of dimensions [batch_size, mask_seq_length]
        # - if the model is a decoder, apply a causal mask in addition to the padding mask
        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
        if self.is_decoder:
            seq_ids = tf.range(mask_seq_length)
            causal_mask = tf.less_equal(
                tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                seq_ids[None, :, None],
            )
            causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
            extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :]
            attention_mask_shape = shape_list(extended_attention_mask)
            extended_attention_mask = tf.reshape(
                extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
            )
        else:
            extended_attention_mask = tf.reshape(
                inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
            )
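The `tf.less_equal` construction is just a lower-triangular mask combined with the padding mask: entry (i, j) is 1 exactly when position i may attend to position j <= i. A quick standalone check with `mask_seq_length = 3`:

import tensorflow as tf

mask_seq_length = 3
seq_ids = tf.range(mask_seq_length)
causal_mask = tf.less_equal(
    tf.tile(seq_ids[None, None, :], (1, mask_seq_length, 1)),
    seq_ids[None, :, None],
)
print(tf.cast(causal_mask, tf.int32)[0].numpy())
# [[1 0 0]
#  [1 1 0]
#  [1 1 1]]  (row i attends to columns <= i)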
        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
@@ -545,6 +721,29 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

        # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
        if self.is_decoder and inputs["encoder_attention_mask"] is not None:
            # If a 2D or 3D attention mask is provided for the cross-attention,
            # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
            inputs["encoder_attention_mask"] = tf.cast(
                inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
            )
            num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
            if num_dims_encoder_attention_mask == 3:
                encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
            if num_dims_encoder_attention_mask == 2:
                encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]

            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
            #                                                 tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
        else:
            encoder_extended_attention_mask = None
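The cross-attention mask uses the same additive-mask trick: a 2D (batch, src_len) padding mask is broadcast to 4D and mapped to 0 for visible tokens and -10000 for padding. A tiny standalone illustration:

import tensorflow as tf

encoder_attention_mask = tf.constant([[1.0, 1.0, 0.0]])  # last source token is padding
mask_4d = encoder_attention_mask[:, None, None, :]       # (batch, 1, 1, src_len)
additive = (1.0 - mask_4d) * -10000.0
print(additive.numpy())  # 0 where attending is allowed, -10000.0 on the padding slot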
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
@@ -559,6 +758,10 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
            hidden_states=embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=inputs["head_mask"],
            encoder_hidden_states=inputs["encoder_hidden_states"],
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=inputs["past_key_values"],
            use_cache=inputs["use_cache"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
            return_dict=inputs["return_dict"],
@@ -574,11 +777,13 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
                pooled_output,
            ) + encoder_outputs[1:]

        return TFBaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
@@ -591,6 +796,25 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
    config_class = RobertaConfig
    base_model_prefix = "roberta"
    @property
    # Copied from transformers.models.bert.modeling_tf_bert.TFBertPreTrainedModel.dummy_inputs
    def dummy_inputs(self):
        """
        Dummy inputs to build the network.

        Returns:
            :obj:`Dict[str, tf.Tensor]`: The dummy inputs.
        """
        dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
        # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
        if self.config.add_cross_attention:
            batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
            shape = (batch_size, seq_len) + (self.config.hidden_size,)
            h = tf.random.uniform(shape=shape)
            dummy["encoder_hidden_states"] = h

        return dummy
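The extra dummy tensor matters because Keras only creates a layer's weights on first use: without an `encoder_hidden_states` entry, the `crossattention` sublayers would never be built and their weights could not be saved or restored. A hedged sketch of the effect (config values are illustrative):

from transformers import RobertaConfig, TFRobertaModel

config = RobertaConfig(is_decoder=True, add_cross_attention=True)  # illustrative config
model = TFRobertaModel(config)
model(model.dummy_inputs)  # first call materializes all weights, including crossattention
assert any("crossattention" in w.name for w in model.weights)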
    @tf.function(
        input_signature=[
            {
@@ -711,7 +935,7 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
@@ -722,12 +946,36 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
        **kwargs,
    ):
r"""
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`)
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation
"""
        inputs = input_processing(
            func=self.call,
            config=self.config,
@@ -737,6 +985,10 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
@@ -750,6 +1002,10 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
            position_ids=inputs["position_ids"],
            head_mask=inputs["head_mask"],
            inputs_embeds=inputs["inputs_embeds"],
            encoder_hidden_states=inputs["encoder_hidden_states"],
            encoder_attention_mask=inputs["encoder_attention_mask"],
            past_key_values=inputs["past_key_values"],
            use_cache=inputs["use_cache"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
            return_dict=inputs["return_dict"],
@@ -759,15 +1015,24 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
        return outputs

    # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
    def serving_output(
        self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
    ) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
        output_cache = self.config.use_cache and self.config.is_decoder
        pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
        hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
        attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
        if not (self.config.output_attentions and self.config.add_cross_attention):
            cross_attns = None

        return TFBaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=output.last_hidden_state,
            pooler_output=output.pooler_output,
            past_key_values=pkv,
            hidden_states=hs,
            attentions=attns,
            cross_attentions=cross_attns,
        )
@@ -922,6 +1187,163 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss):
        return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"]

    def __init__(self, config: RobertaConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        if not config.is_decoder:
            logger.warning("If you want to use `TFRobertaForCausalLM` as a standalone, add `is_decoder=True`.")

        self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
        self.lm_head = TFRobertaLMHead(config, input_embeddings=self.roberta.embeddings, name="lm_head")

    def get_lm_head(self):
        return self.lm_head

    def get_prefix_bias_name(self):
        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.lm_head.name

    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
        # cut decoder_input_ids if past is used
        if past:
            inputs = tf.expand_dims(inputs[:, -1], -1)

        return {
            "input_ids": inputs,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "use_cache": model_kwargs["use_cache"],
        }
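During generation this is what keeps the per-step input down to one token once a cache exists; everything before the last position is already represented in `past`. The trimming step in isolation:

import tensorflow as tf

input_ids = tf.constant([[5, 17, 42, 8]])
past = object()  # stands in for a non-empty cache
if past:
    input_ids = tf.expand_dims(input_ids[:, -1], -1)
print(input_ids.numpy())  # [[8]]: only the newest token is fed; the rest comes from the cache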
    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFCausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
        **kwargs,
    ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
r"""
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`)
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
config.vocab_size - 1]``.
"""
        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
            training=training,
            kwargs_call=kwargs,
        )

        outputs = self.roberta(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            token_type_ids=inputs["token_type_ids"],
            position_ids=inputs["position_ids"],
            head_mask=inputs["head_mask"],
            inputs_embeds=inputs["inputs_embeds"],
            encoder_hidden_states=inputs["encoder_hidden_states"],
            encoder_attention_mask=inputs["encoder_attention_mask"],
            past_key_values=inputs["past_key_values"],
            use_cache=inputs["use_cache"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )
        sequence_output = outputs[0]
        logits = self.lm_head(hidden_states=sequence_output)
        loss = None

        if inputs["labels"] is not None:
            # shift labels to the left and cut last logit token
            logits = logits[:, :-1]
            labels = inputs["labels"][:, 1:]
            loss = self.compute_loss(labels=labels, logits=logits)

        if not inputs["return_dict"]:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFCausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
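The loss block implements the standard causal-LM shift: the logits at position t are scored against the label at position t + 1. A quick check of the alignment (values are arbitrary):

import tensorflow as tf

labels = tf.constant([[10, 11, 12, 13]])
logits = tf.random.normal((1, 4, 50265))  # (batch, seq, vocab); vocab size illustrative

shifted_logits = logits[:, :-1]  # predictions made at positions 0..2
shifted_labels = labels[:, 1:]   # targets 11, 12, 13
# compute_loss then compares shifted_logits[:, t] against shifted_labels[:, t]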
    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
    def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions:
        output_cache = self.config.use_cache and self.config.is_decoder
        pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
        hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
        attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
        if not (self.config.output_attentions and self.config.add_cross_attention):
            cross_attns = None

        return TFCausalLMOutputWithCrossAttentions(
            logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
        )
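With the causal-LM head in place, a RoBERTa checkpoint can serve as the decoder half of the `TFEncoderDecoderModel` this PR introduces. A hedged usage sketch (checkpoint names are examples only; the decoder is configured with `is_decoder=True` and `add_cross_attention=True` under the hood):

from transformers import RobertaTokenizer, TFEncoderDecoderModel

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")

inputs = tokenizer("A short source sentence.", return_tensors="tf")
outputs = model(input_ids=inputs.input_ids, decoder_input_ids=inputs.input_ids)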
class TFRobertaClassificationHead(tf.keras.layers.Layer):
    """Head for sentence-level classification tasks."""
...
@@ -18,6 +18,7 @@
from ...file_utils import add_start_docstrings
from ...utils import logging
from ..roberta.modeling_tf_roberta import (
    TFRobertaForCausalLM,
    TFRobertaForMaskedLM,
    TFRobertaForMultipleChoice,
    TFRobertaForQuestionAnswering,
@@ -85,6 +86,19 @@ class TFXLMRobertaModel(TFRobertaModel):
    config_class = XLMRobertaConfig
@add_start_docstrings(
    "XLM-RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.",
    XLM_ROBERTA_START_DOCSTRING,
)
class TFXLMRobertaForCausalLM(TFRobertaForCausalLM):
    """
    This class overrides :class:`~transformers.TFRobertaForCausalLM`. Please check the superclass for the appropriate
    documentation alongside usage examples.
    """

    config_class = XLMRobertaConfig
@add_start_docstrings(
    """XLM-RoBERTa Model with a `language modeling` head on top. """,
    XLM_ROBERTA_START_DOCSTRING,
...
@@ -929,6 +929,15 @@ class TFElectraPreTrainedModel:
        requires_backends(cls, ["tf"])
class TFEncoderDecoderModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -1712,6 +1721,15 @@ class TFRemBertPreTrainedModel:
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFRobertaForCausalLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])
class TFRobertaForMaskedLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])
...
@@ -24,15 +24,16 @@ import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import (
    DUMMY_INPUTS,
    MULTIPLE_CHOICE_DUMMY_INPUTS,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)
from ...modeling_tf_outputs import (
    TFBaseModelOutputWithPastAndCrossAttentions,
    TFBaseModelOutputWithPoolingAndCrossAttentions,
    TFCausalLMOutputWithCrossAttentions,
    TFMaskedLMOutput,
    TFMultipleChoiceModelOutput,
    TFQuestionAnsweringModelOutput,
@@ -116,6 +117,7 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
        position_ids: tf.Tensor = None,
        token_type_ids: tf.Tensor = None,
        inputs_embeds: tf.Tensor = None,
        past_key_values_length=0,
        training: bool = False,
    ) -> tf.Tensor:
        """
@@ -135,7 +137,9 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
            token_type_ids = tf.fill(dims=input_shape, value=0)

        if position_ids is None:
            position_ids = tf.expand_dims(
                tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
            )

        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
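With a cache present, freshly fed tokens must receive position ids that continue after the cached prefix, which is exactly what the `past_key_values_length` offset provides. A standalone check:

import tensorflow as tf

past_key_values_length = 5  # five tokens already in the cache
new_tokens = 1              # one token fed this step
position_ids = tf.expand_dims(
    tf.range(start=past_key_values_length, limit=new_tokens + past_key_values_length), axis=0
)
print(position_ids.numpy())  # [[5]]: numbering continues where the cached positions stopped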
@@ -174,6 +178,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer):
        )
        self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
@@ -186,16 +192,49 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer):
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(inputs=hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k, v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
            value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
            key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
            value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
        else:
            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)

        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)

        if self.is_decoder:
            # if cross_attention, save Tuple(tf.Tensor, tf.Tensor) of all cross-attention key/value_states.
            # Further calls to the cross-attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder), save Tuple(tf.Tensor, tf.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention, `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
@@ -225,6 +264,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer):
        attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)

        return outputs
@@ -263,6 +304,9 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
@@ -270,13 +314,17 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
            hidden_states=input_tensor,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
        )
        # add attentions (possibly with past_key_value) if we output them
        outputs = (attention_output,) + self_outputs[1:]

        return outputs
@@ -327,6 +375,12 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
        super().__init__(**kwargs)

        self.attention = TF{{cookiecutter.camelcase_modelname}}Attention(config, name="attention")
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = TF{{cookiecutter.camelcase_modelname}}Attention(config, name="crossattention")
        self.intermediate = TF{{cookiecutter.camelcase_modelname}}Intermediate(config, name="intermediate")
        self.bert_output = TF{{cookiecutter.camelcase_modelname}}Output(config, name="output")
@@ -335,20 +389,69 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: Optional[tf.Tensor],
        encoder_attention_mask: Optional[tf.Tensor],
        past_key_value: Optional[Tuple[tf.Tensor]],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_value=self_attn_past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is a tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
                    "by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                input_tensor=attention_output,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        intermediate_output = self.intermediate(hidden_states=attention_output)
        layer_output = self.bert_output(
            hidden_states=intermediate_output, input_tensor=attention_output, training=training
        )
        outputs = (layer_output,) + outputs  # add attentions if we output them

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs
@@ -357,7 +460,7 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
    def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        self.layer = [TF{{cookiecutter.camelcase_modelname}}Layer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    def call(
@@ -365,39 +468,61 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: Optional[tf.Tensor],
        encoder_attention_mask: Optional[tf.Tensor],
        past_key_values: Optional[Tuple[Tuple[tf.Tensor]]],
        use_cache: Optional[bool],
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            past_key_value = past_key_values[i] if past_key_values is not None else None

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention and encoder_hidden_states is not None:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
            )

        return TFBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
@@ -492,6 +617,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
        super().__init__(**kwargs)

        self.config = config
        self.is_decoder = config.is_decoder

        self.embeddings = TF{{cookiecutter.camelcase_modelname}}Embeddings(config, name="embeddings")
        self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, name="encoder")
@@ -521,12 +647,16 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
        **kwargs,
    ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
        inputs = input_processing(
            func=self.call,
            config=self.config,
...@@ -536,6 +666,10 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ...@@ -536,6 +666,10 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -543,6 +677,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
kwargs_call=kwargs,
)
if not self.config.is_decoder:
inputs["use_cache"] = False
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None: elif inputs["input_ids"] is not None:
...@@ -552,8 +689,16 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ...@@ -552,8 +689,16 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
if inputs["past_key_values"] is None:
past_key_values_length = 0
inputs["past_key_values"] = [None] * len(self.encoder.layer)
else:
past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2]
if inputs["attention_mask"] is None: if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(dims=input_shape, value=1) inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
if inputs["token_type_ids"] is None: if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0) inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
...@@ -563,6 +708,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ...@@ -563,6 +708,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
position_ids=inputs["position_ids"], position_ids=inputs["position_ids"],
token_type_ids=inputs["token_type_ids"], token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
past_key_values_length=past_key_values_length,
training=inputs["training"], training=inputs["training"],
) )
...@@ -571,7 +717,29 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ...@@ -571,7 +717,29 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1])) attention_mask_shape = shape_list(inputs["attention_mask"])
mask_seq_length = seq_length + past_key_values_length
# Copied from `modeling_tf_t5.py`
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
if self.is_decoder:
seq_ids = tf.range(mask_seq_length)
causal_mask = tf.less_equal(
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
seq_ids[None, :, None],
)
causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :]
attention_mask_shape = shape_list(extended_attention_mask)
extended_attention_mask = tf.reshape(
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
)
else:
extended_attention_mask = tf.reshape(
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -583,6 +751,29 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000
if self.is_decoder and inputs["encoder_attention_mask"] is not None:
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
inputs["encoder_attention_mask"] = tf.cast(
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
)
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
if num_dims_encoder_attention_mask == 3:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
if num_dims_encoder_attention_mask == 2:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
else:
encoder_extended_attention_mask = None
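# A short sketch, with an assumed toy mask, of the 2D -> 4D broadcast and
# inversion performed just above for the cross-attention padding mask:
import tensorflow as tf

enc_mask = tf.constant([[1.0, 1.0, 0.0]])  # (batch, src_len), last token is padding
enc_ext = enc_mask[:, None, None, :]       # (batch, 1, 1, src_len)
enc_ext = (1.0 - enc_ext) * -10000.0       # additive bias: -0. keep, -10000. mask
print(enc_ext.numpy())                     # [[[[    -0.     -0. -10000.]]]]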
# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape bsz x n_heads x N x N
@@ -597,6 +788,10 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
hidden_states=embedding_output,
attention_mask=extended_attention_mask,
head_mask=inputs["head_mask"],
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
@@ -610,10 +805,12 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
sequence_output,
) + encoder_outputs[1:]
return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=sequence_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
@@ -625,6 +822,24 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
config_class = {{cookiecutter.camelcase_modelname}}Config
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
@property
def dummy_inputs(self):
"""
Dummy inputs to build the network.
Returns:
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
"""
dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
if self.config.add_cross_attention:
batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
shape = (batch_size, seq_len) + (self.config.hidden_size,)
h = tf.random.uniform(shape=shape)
dummy["encoder_hidden_states"] = h
return dummy
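# A hedged illustration of the property above: `DUMMY_INPUTS` is a small
# (3, 5) integer batch, so with `add_cross_attention=True` the dummy dict also
# carries a random (3, 5, hidden_size) tensor; hidden_size=768 is assumed here.
import tensorflow as tf
from transformers.file_utils import DUMMY_INPUTS

batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
hidden_size = 768  # stands in for self.config.hidden_size
dummy = {
    "input_ids": tf.constant(DUMMY_INPUTS),
    "encoder_hidden_states": tf.random.uniform((batch_size, seq_len, hidden_size)),
}
# feeding `dummy` through the model once forces Keras to build the
# cross-attention weights, so they can be saved and loaded by name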
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
@@ -732,7 +947,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def call(
@@ -743,12 +958,36 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs,
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
r"""
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation.
"""
inputs = input_processing(
func=self.call,
config=self.config,
@@ -758,6 +997,10 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -771,6 +1014,10 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
@@ -779,12 +1026,26 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
def serving_output(
self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
output_cache = self.config.use_cache and self.config.is_decoder
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
if not (self.config.output_attentions and self.config.add_cross_attention):
cross_attns = None
return TFBaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
past_key_values=pkv,
hidden_states=hs,
attentions=attns,
cross_attentions=cross_attns,
)
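# A small sketch of what `tf.convert_to_tensor` does to the tuple outputs in
# `serving_output` above, with assumed toy shapes: a tuple of per-layer
# attention tensors becomes one tensor with a leading layer axis.
import tensorflow as tf

num_layers, batch, heads, q_len, k_len = 2, 1, 4, 3, 3
attns = tuple(tf.zeros((batch, heads, q_len, k_len)) for _ in range(num_layers))
print(tf.convert_to_tensor(attns).shape)  # (2, 1, 4, 3, 3)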
@add_start_docstrings("""{{cookiecutter.modelname}} Model with a `language modeling` head on top. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING)
@@ -903,10 +1164,22 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
def get_lm_head(self) -> tf.keras.layers.Layer:
return self.mlm.predictions
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
# cut decoder_input_ids if past is used
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": model_kwargs["use_cache"],
}
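# A standalone sketch of the cut performed above, assuming a cache already
# covers the earlier tokens: only the newest token id is fed back in.
import tensorflow as tf

generated = tf.constant([[5, 8, 2, 9, 7]])
past = object()  # stands in for a real past_key_values tuple
if past:
    generated = tf.expand_dims(generated[:, -1], -1)
print(generated.numpy())  # [[7]]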
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFCausalLMOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def call(
@@ -917,14 +1190,36 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training: Optional[bool] = False,
**kwargs,
) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
r"""
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation.
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
config.vocab_size - 1]``.
@@ -938,6 +1233,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -952,6 +1251,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
@@ -971,19 +1274,28 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFCausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions:
output_cache = self.config.use_cache and self.config.is_decoder
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
if not (self.config.output_attentions and self.config.add_cross_attention):
cross_attns = None
return TFCausalLMOutputWithCrossAttentions(
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
)
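# A hedged sketch of incremental decoding with the cache documented above,
# using a tiny randomly initialized TFBertLMHeadModel as a stand-in for this
# class (all sizes below are assumptions, not template values):
import tensorflow as tf
from transformers import BertConfig, TFBertLMHeadModel

cfg = BertConfig(
    vocab_size=64, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=64, is_decoder=True,
)
model = TFBertLMHeadModel(cfg)

prompt = tf.constant([[3, 7, 9]])
out = model(prompt, use_cache=True)  # run the full prompt once
next_token = tf.cast(tf.argmax(out.logits[:, -1:, :], axis=-1), tf.int32)
# with the cache, only the newest token is fed on the next step
out = model(next_token, past_key_values=out.past_key_values, use_cache=True)
print(out.logits.shape)  # (1, 1, 64)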
class TF{{cookiecutter.camelcase_modelname}}ClassificationHead(tf.keras.layers.Layer):
@@ -21,7 +21,7 @@ from transformers import is_tf_available, {{cookiecutter.camelcase_modelname}}Co
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
@@ -123,6 +123,33 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = self.prepare_config_and_inputs()
config.is_decoder = True
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
)
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@@ -21,7 +21,7 @@ from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
@@ -125,6 +125,33 @@ class TFBertModelTester:
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = self.prepare_config_and_inputs()
config.is_decoder = True
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
)
def create_and_check_bert_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@@ -1393,6 +1393,22 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
return output
def floats_tensor(shape, scale=1.0, rng=None, name=None, dtype=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.random() * scale)
return tf.reshape(tf.constant(values, dtype=dtype if dtype is not None else tf.float32), shape=shape)
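# Usage sketch for the helper above; `floats_tensor` is assumed to be called
# from this module, where `tf` and `random` are already imported:
t = floats_tensor([2, 3], scale=1.0)  # values uniform in [0, scale)
print(t.shape, t.dtype)  # (2, 3) <dtype: 'float32'>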
@require_tf
class UtilsFunctionsTest(unittest.TestCase):
# coding=utf-8
# Copyright 2020 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
from transformers import is_tf_available, is_torch_available
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_torch, slow, torch_device
from .test_modeling_tf_bert import TFBertModelTester
from .test_modeling_tf_common import ids_tensor
from .test_modeling_tf_rembert import TFRemBertModelTester
from .test_modeling_tf_roberta import TFRobertaModelTester
if is_tf_available():
from transformers import (
AutoConfig,
AutoTokenizer,
EncoderDecoderConfig,
TFAutoModel,
TFAutoModelForCausalLM,
TFBertLMHeadModel,
TFBertModel,
TFEncoderDecoderModel,
TFRemBertForCausalLM,
TFRemBertModel,
TFRobertaForCausalLM,
TFRobertaModel,
)
from transformers.modeling_tf_outputs import TFBaseModelOutput
if is_torch_available():
import torch
from transformers import BertLMHeadModel, BertModel, EncoderDecoderModel
@require_tf
class TFEncoderDecoderMixin:
def get_encoder_decoder_model(self, config, decoder_config):
raise NotImplementedError
def prepare_config_and_inputs(self):
raise NotImplementedError
def get_pretrained_model(self):
raise NotImplementedError
def check_encoder_decoder_model_from_pretrained_configs(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
enc_dec_model = TFEncoderDecoderModel(encoder_decoder_config)
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
)
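# A minimal, hedged example of the config composition exercised above, using
# small assumed BERT hyperparameters:
from transformers import BertConfig

enc_cfg = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=4)
dec_cfg = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=4)
cfg = EncoderDecoderConfig.from_encoder_decoder_configs(enc_cfg, dec_cfg)
# the helper marks the decoder side as a decoder with cross-attention
assert cfg.decoder.is_decoder and cfg.decoder.add_cross_attention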
def check_encoder_decoder_model(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
)
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_hidden_states)
outputs_encoder_decoder = enc_dec_model(
input_ids=None,
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
)
def check_encoder_decoder_model_from_pretrained(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
return_dict,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
enc_dec_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
)
def check_save_and_load(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
outputs = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmpdirname:
enc_dec_model.save_pretrained(tmpdirname)
enc_dec_model = TFEncoderDecoderModel.from_pretrained(tmpdirname)
after_outputs = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def check_encoder_decoder_model_labels(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
labels,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
labels=labels,
)
# Make sure `loss` exists
assert "loss" in outputs_encoder_decoder
batch_size, seq_len = decoder_input_ids.shape
expected_shape = (batch_size, seq_len - 1, decoder_config.vocab_size)
self.assertEqual(outputs_encoder_decoder["logits"].shape, expected_shape)
self.assertEqual(
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
)
def check_encoder_decoder_model_output_attentions(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
self.assertEqual(
encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
)
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertEqual(
decoder_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
)
cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
)
self.assertEqual(
cross_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
)
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
)
self.assertEqual(tuple(generated_output.shape.as_list()), (input_ids.shape[0],) + (decoder_config.max_length,))
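# A hedged sketch of the generate call checked above, with a tiny random
# bert2bert model (random weights, so the generated ids are arbitrary):
import tensorflow as tf
from transformers import BertConfig, EncoderDecoderConfig, TFEncoderDecoderModel

tiny = dict(vocab_size=64, hidden_size=32, num_hidden_layers=2,
            num_attention_heads=4, intermediate_size=64)
config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(**tiny), BertConfig(**tiny))
model = TFEncoderDecoderModel(config)

input_ids = tf.constant([[5, 8, 2, 9, 7]])
out = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id, max_length=8)
print(out.shape)  # typically (1, 8)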
def test_encoder_decoder_model(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model(**input_ids_dict)
def test_encoder_decoder_model_from_pretrained_configs(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
def test_encoder_decoder_model_from_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=False)
def test_encoder_decoder_model_from_pretrained_return_dict(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True)
def test_save_and_load_from_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_save_and_load(**input_ids_dict)
def test_encoder_decoder_model_labels(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_labels(**input_ids_dict)
def test_encoder_decoder_model_output_attentions(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
def test_encoder_decoder_model_generate(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**input_ids_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
attention_mask = ids_tensor([13, 5], vocab_size=2)
outputs = model_2(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmp_dirname:
model_2.save_pretrained(tmp_dirname)
model_1 = TFEncoderDecoderModel.from_pretrained(tmp_dirname)
after_outputs = model_1(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
@require_tf
class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = TFBertModel(config, name="encoder")
decoder_model = TFBertLMHeadModel(decoder_config, name="decoder")
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = TFBertModelTester(self, batch_size=13)
model_tester_decoder = TFBertModelTester(self, batch_size=13)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
token_type_ids,
attention_mask,
sequence_labels,
token_labels,
choice_labels,
) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_token_type_ids,
decoder_attention_mask,
decoder_sequence_labels,
decoder_token_labels,
decoder_choice_labels,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_token_type_ids": decoder_token_type_ids,
"decoder_attention_mask": decoder_attention_mask,
"decoder_sequence_labels": decoder_sequence_labels,
"decoder_token_labels": decoder_token_labels,
"decoder_choice_labels": decoder_choice_labels,
"encoder_hidden_states": encoder_hidden_states,
"labels": decoder_token_labels,
}
@slow
@is_pt_tf_cross_test
def test_bert2bert_summarization(self):
from transformers import EncoderDecoderModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
"""Not working, because pt checkpoint has `encoder.encoder.layer...` while tf model has `encoder.bert.encoder.layer...`
(For Bert decoder, there is no issue, because `BertModel` is wrapped into `decoder` as `bert`)
model = TFEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16", from_pt=True)
"""
# workaround to load from pt
_model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
_model.encoder.save_pretrained("./encoder")
_model.decoder.save_pretrained("./decoder")
model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
)
model.config = _model.config
ARTICLE_STUDENTS = """(CNN)Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members singing a racist chant. SAE's national chapter suspended the students, but University of Oklahoma President David Boren took it a step further, saying the university's affiliation with the fraternity is permanently done. The news is shocking, but it's not the first time SAE has faced controversy. SAE was founded March 9, 1856, at the University of Alabama, five years before the American Civil War, according to the fraternity website. When the war began, the group had fewer than 400 members, of which "369 went to war for the Confederate States and seven for the Union Army," the website says. The fraternity now boasts more than 200,000 living alumni, along with about 15,000 undergraduates populating 219 chapters and 20 "colonies" seeking full membership at universities. SAE has had to work hard to change recently after a string of member deaths, many blamed on the hazing of new recruits, SAE national President Bradley Cohen wrote in a message on the fraternity's website. The fraternity's website lists more than 130 chapters cited or suspended for "health and safety incidents" since 2010. At least 30 of the incidents involved hazing, and dozens more involved alcohol. However, the list is missing numerous incidents from recent months. Among them, according to various media outlets: Yale University banned the SAEs from campus activities last month after members allegedly tried to interfere with a sexual misconduct investigation connected to an initiation rite. Stanford University in December suspended SAE housing privileges after finding sorority members attending a fraternity function were subjected to graphic sexual content. And Johns Hopkins University in November suspended the fraternity for underage drinking. "The media has labeled us as the 'nation's deadliest fraternity,' " Cohen said. In 2011, for example, a student died while being coerced into excessive alcohol consumption, according to a lawsuit. SAE's previous insurer dumped the fraternity. "As a result, we are paying Lloyd's of London the highest insurance rates in the Greek-letter world," Cohen said. Universities have turned down SAE's attempts to open new chapters, and the fraternity had to close 12 in 18 months over hazing incidents."""
EXPECTED_SUMMARY_STUDENTS = """sae was founded in 1856, five years before the civil war. the fraternity has had to work hard to change recently. the university of oklahoma president says the university's affiliation with the fraternity is permanently done. the sae has had a string of members in recent months."""
input_dict = tokenizer(ARTICLE_STUDENTS, return_tensors="tf")
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
@require_tf
class TFRobertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = TFRobertaModel(config, name="encoder")
decoder_model = TFRobertaForCausalLM(decoder_config, name="decoder")
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = TFRobertaModelTester(self)
model_tester_decoder = TFRobertaModelTester(self)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_token_type_ids,
decoder_input_mask,
decoder_sequence_labels,
decoder_token_labels,
decoder_choice_labels,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"input_ids": input_ids,
"attention_mask": input_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_token_type_ids": decoder_token_type_ids,
"decoder_attention_mask": decoder_input_mask,
"decoder_sequence_labels": decoder_sequence_labels,
"decoder_token_labels": decoder_token_labels,
"decoder_choice_labels": decoder_choice_labels,
"encoder_hidden_states": encoder_hidden_states,
"labels": decoder_token_labels,
}
@require_tf
class TFRembertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("google/rembert", "google/rembert")
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = TFRemBertModel(config, name="encoder")
decoder_model = TFRemBertForCausalLM(decoder_config, name="decoder")
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = TFRemBertModelTester(self)
model_tester_decoder = TFRemBertModelTester(self)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_token_type_ids,
decoder_input_mask,
decoder_sequence_labels,
decoder_token_labels,
decoder_choice_labels,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"input_ids": input_ids,
"attention_mask": input_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_token_type_ids": decoder_token_type_ids,
"decoder_attention_mask": decoder_input_mask,
"decoder_sequence_labels": decoder_sequence_labels,
"decoder_token_labels": decoder_token_labels,
"decoder_choice_labels": decoder_choice_labels,
"encoder_hidden_states": encoder_hidden_states,
"labels": decoder_token_labels,
}
@require_tf
class TFEncoderDecoderModelTest(unittest.TestCase):
def get_from_encoderdecoder_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased")
def get_decoder_config(self):
config = AutoConfig.from_pretrained("bert-base-cased")
config.is_decoder = True
config.add_cross_attention = True
return config
def get_encoderdecoder_model(self):
return TFEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
def get_encoder_decoder_models(self):
encoder_model = TFBertModel.from_pretrained("bert-base-cased", name="encoder")
decoder_model = TFBertLMHeadModel.from_pretrained(
"bert-base-cased", config=self.get_decoder_config(), name="decoder"
)
return {"encoder": encoder_model, "decoder": decoder_model}
def _check_configuration_tie(self, model):
assert id(model.decoder.config) == id(model.config.decoder)
assert id(model.encoder.config) == id(model.config.encoder)
@slow
def test_configuration_tie(self):
model = self.get_from_encoderdecoder_pretrained_model()
self._check_configuration_tie(model)
model = TFEncoderDecoderModel(**self.get_encoder_decoder_models())
self._check_configuration_tie(model)
# # This should be enabled once we upload the TF version of
# # "patrickvonplaten/bert2bert-cnn_dailymail-fp16" to the Hub.
# model = self.get_encoderdecoder_model()
# self._check_configuration_tie(model)
@require_tf
class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
def get_encoder_decoder_config(self):
encoder_config = AutoConfig.from_pretrained("bert-base-uncased")
decoder_config = AutoConfig.from_pretrained("bert-base-uncased", is_decoder=True, add_cross_attention=True)
return EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
def get_encoder_decoder_config_small(self):
encoder_config = AutoConfig.from_pretrained("hf-internal-testing/tiny-bert")
decoder_config = AutoConfig.from_pretrained(
"hf-internal-testing/tiny-bert", is_decoder=True, add_cross_attention=True
)
return EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
def test_encoder_decoder_save_load_from_encoder_decoder(self):
config = self.get_encoder_decoder_config_small()
# create two random BERT models for bert2bert & initialize weights (+cross_attention weights)
encoder = TFBertModel(config.encoder)
encoder(encoder.dummy_inputs)
decoder = TFBertLMHeadModel(config.decoder)
decoder(decoder.dummy_inputs)
encoder_decoder_orig = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
input_ids = ids_tensor([13, 5], encoder.config.vocab_size)
decoder_input_ids = ids_tensor([13, 1], decoder.config.vocab_size)
logits_orig = encoder_decoder_orig(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder_path = os.path.join(tmp_dirname, "encoder")
decoder_path = os.path.join(tmp_dirname, "decoder")
encoder.save_pretrained(encoder_path)
decoder.save_pretrained(decoder_path)
encoder_decoder = TFEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_path, decoder_path)
logits_1 = encoder_decoder(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
self.assertTrue(abs(logits_orig.numpy().sum() - logits_1.numpy().sum()) < 1e-3)
max_diff = np.max(np.abs(logits_1.numpy() - logits_orig.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=4)
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder_decoder.save_pretrained(tmp_dirname)
encoder_decoder = TFEncoderDecoderModel.from_pretrained(tmp_dirname)
logits_2 = encoder_decoder(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
max_diff = np.max(np.abs(logits_2.numpy() - logits_orig.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=4)
@require_torch
@is_pt_tf_cross_test
def test_encoder_decoder_save_load_from_encoder_decoder_from_pt(self):
config = self.get_encoder_decoder_config_small()
# create two random BERT models for bert2bert & initialize weights (+cross_attention weights)
encoder_pt = BertModel(config.encoder).to(torch_device).eval()
decoder_pt = BertLMHeadModel(config.decoder).to(torch_device).eval()
encoder_decoder_pt = EncoderDecoderModel(encoder=encoder_pt, decoder=decoder_pt).to(torch_device).eval()
input_ids = ids_tensor([13, 5], encoder_pt.config.vocab_size)
decoder_input_ids = ids_tensor([13, 1], decoder_pt.config.vocab_size)
pt_input_ids = torch.tensor(input_ids.numpy(), device=torch_device, dtype=torch.long)
pt_decoder_input_ids = torch.tensor(decoder_input_ids.numpy(), device=torch_device, dtype=torch.long)
logits_pt = encoder_decoder_pt(input_ids=pt_input_ids, decoder_input_ids=pt_decoder_input_ids).logits
# PyTorch => TensorFlow
with tempfile.TemporaryDirectory() as tmp_dirname_1, tempfile.TemporaryDirectory() as tmp_dirname_2:
encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True
)
logits_tf = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=3)
# TensorFlow => PyTorch
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder_decoder_tf.save_pretrained(tmp_dirname)
encoder_decoder_pt = EncoderDecoderModel.from_pretrained(tmp_dirname, from_tf=True)
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=3)
@slow
def test_encoder_decoder_from_pretrained(self):
load_weight_prefix = "tf_encoder_decoder_model_1"
config = self.get_encoder_decoder_config()
encoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
decoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
input_ids = encoder_tokenizer("who sings does he love me with reba", return_tensors="tf").input_ids
decoder_input_ids = decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids
with tempfile.TemporaryDirectory() as tmp_dirname:
# Since most of HF's models don't have pretrained cross-attention layers, they are randomly
# initialized even if we create models using the `from_pretrained` method.
# For the tests, the decoder needs to be a model with pretrained cross-attention layers.
# So we create pretrained models (without `load_weight_prefix`), save them, and later,
# we load them using `from_pretrained`.
# (we don't need to do this for the encoder, but let's make the code more similar between encoder/decoder)
encoder = TFAutoModel.from_pretrained("bert-base-uncased", name="encoder")
# It's necessary to specify `add_cross_attention=True` here.
decoder = TFAutoModelForCausalLM.from_pretrained(
"bert-base-uncased", is_decoder=True, add_cross_attention=True, name="decoder"
)
pretrained_encoder_dir = os.path.join(tmp_dirname, "pretrained_encoder")
pretrained_decoder_dir = os.path.join(tmp_dirname, "pretrained_decoder")
encoder.save_pretrained(pretrained_encoder_dir)
decoder.save_pretrained(pretrained_decoder_dir)
del encoder
del decoder
enc_dec_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
pretrained_encoder_dir,
pretrained_decoder_dir,
)
# check that the from pretrained methods work
enc_dec_model.save_pretrained(tmp_dirname)
enc_dec_model = TFEncoderDecoderModel.from_pretrained(tmp_dirname)
output = enc_dec_model(input_ids, decoder_input_ids=decoder_input_ids, labels=decoder_input_ids)
loss_pretrained = output.loss
del enc_dec_model
# Create the model using `__init__` with loaded ``pretrained`` encoder / decoder
encoder = TFAutoModel.from_pretrained(
pretrained_encoder_dir, load_weight_prefix=load_weight_prefix, name="encoder"
)
decoder = TFAutoModelForCausalLM.from_pretrained(
pretrained_decoder_dir, load_weight_prefix=load_weight_prefix, name="decoder"
)
enc_dec_model = TFEncoderDecoderModel(config=config, encoder=encoder, decoder=decoder)
output = enc_dec_model(input_ids, decoder_input_ids=decoder_input_ids, labels=decoder_input_ids)
loss_init = output.loss
max_diff = np.max(np.abs(loss_pretrained - loss_init))
expected_diff = 0.0
self.assertAlmostEqual(max_diff, expected_diff, places=4)
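Outside the test harness, the same API builds a warm-started TF encoder-decoder directly. A minimal, hedged sketch (requires downloading `bert-base-uncased`; the cross-attention weights are newly initialized, so the loss is only meaningful after fine-tuning):

from transformers import AutoTokenizer, TFEncoderDecoderModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

input_ids = tokenizer("who sings does he love me with reba", return_tensors="tf").input_ids
decoder_input_ids = tokenizer("Linda Davis", return_tensors="tf").input_ids

# Passing `labels` produces a language-modeling loss alongside the logits.
outputs = model(input_ids, decoder_input_ids=decoder_input_ids, labels=decoder_input_ids)
print(float(outputs.loss))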
@@ -20,7 +20,7 @@ from transformers import RemBertConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from .test_configuration_common import ConfigTester
-from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor

 if is_tf_available():
@@ -131,6 +131,33 @@ class TFRemBertModelTester:
         return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

+    def prepare_config_and_inputs_for_decoder(self):
+        (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        ) = self.prepare_config_and_inputs()
+
+        config.is_decoder = True
+        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
+        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+
+        return (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        )
+
     def create_and_check_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
...
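For orientation, here is a hedged sketch (not part of the diff) of how these decoder-mode inputs are consumed: with `config.is_decoder` set, the model routes `encoder_hidden_states` through the cross-attention layers this PR adds, and with `output_attentions=True` it returns one cross-attention map per layer. The variable names mirror the tuple returned by `prepare_config_and_inputs_for_decoder` above; treat it as an illustration, not test code.

# Assumed inputs: config, input_ids, input_mask, encoder_hidden_states,
# encoder_attention_mask from prepare_config_and_inputs_for_decoder.
model = TFRemBertModel(config)
result = model(
    input_ids,
    attention_mask=input_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
    output_attentions=True,
)
# One tensor per layer, each shaped (batch_size, num_heads, seq_length, seq_length).
cross_attentions = result.cross_attentions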
@@ -20,7 +20,7 @@ from transformers import RobertaConfig, is_tf_available
 from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow

 from .test_configuration_common import ConfigTester
-from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor

 if is_tf_available():
@@ -29,6 +29,7 @@ if is_tf_available():
     from transformers.models.roberta.modeling_tf_roberta import (
         TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
+        TFRobertaForCausalLM,
         TFRobertaForMaskedLM,
         TFRobertaForMultipleChoice,
         TFRobertaForQuestionAnswering,
@@ -101,6 +102,33 @@ class TFRobertaModelTester:
         return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

+    def prepare_config_and_inputs_for_decoder(self):
+        (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        ) = self.prepare_config_and_inputs()
+
+        config.is_decoder = True
+        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
+        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+
+        return (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        )
+
     def create_and_check_roberta_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
@@ -115,6 +143,13 @@ class TFRobertaModelTester:
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

+    def create_and_check_roberta_for_causal_lm(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        model = TFRobertaForCausalLM(config=config)
+        result = model([input_ids, input_mask, token_type_ids])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
     def create_and_check_roberta_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
@@ -177,6 +212,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             TFRobertaModel,
+            TFRobertaForCausalLM,
             TFRobertaForMaskedLM,
             TFRobertaForSequenceClassification,
             TFRobertaForTokenClassification,
@@ -203,6 +239,10 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_roberta_for_masked_lm(*config_and_inputs)

+    def test_for_causal_lm(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_roberta_for_causal_lm(*config_and_inputs)
+
     def test_for_token_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_roberta_for_token_classification(*config_and_inputs)
...
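The new `TFRobertaForCausalLM` head can also be exercised directly. A minimal, hedged sketch (the `roberta-base` checkpoint name is illustrative, and its pretrained head is tuned for masked LM, not left-to-right generation):

from transformers import AutoTokenizer, TFRobertaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
# `is_decoder=True` switches on the causal attention mask required for left-to-right LM.
model = TFRobertaForCausalLM.from_pretrained("roberta-base", is_decoder=True)

inputs = tokenizer("Hello, my dog is", return_tensors="tf")
logits = model(**inputs).logits  # (batch_size, sequence_length, vocab_size)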
@@ -160,6 +160,7 @@ def get_model_modules():
         "modeling_flax_utils",
         "modeling_transfo_xl_utilities",
         "modeling_tf_auto",
+        "modeling_tf_encoder_decoder",
         "modeling_tf_outputs",
         "modeling_tf_pytorch_utils",
         "modeling_tf_utils",
@@ -231,6 +232,7 @@ def get_model_test_files():
         "test_modeling_flax_encoder_decoder",
         "test_modeling_marian",
         "test_modeling_tf_common",
+        "test_modeling_tf_encoder_decoder",
     ]
     test_files = []
     for filename in os.listdir(PATH_TO_TESTS):
...