Unverified Commit bd21ed40 authored by Yih-Dar, committed by GitHub

Add cross attentions to TFGPT2Model (#14038)



* Add cross attentions to TFGPT2Model

* change to is_pt_tf_cross_test

* A minor correction to a comment

* Remove n_ctx when creating self.crossattention
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 5f789a68
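With this change the TF GPT-2 stack accepts encoder_hidden_states / encoder_attention_mask and can act as the decoder of a TFEncoderDecoderModel, mirroring the PyTorch implementation. A minimal usage sketch, based on the new test in this commit; the start/pad token settings and generation length are illustrative assumptions, not taken from the diff:

# Sketch: BERT encoder + GPT-2 decoder. Cross-attention layers are added to the
# decoder (config.add_cross_attention=True) when GPT-2 is wrapped as a decoder.
from transformers import AutoTokenizer, TFEncoderDecoderModel

enc_tok = AutoTokenizer.from_pretrained("bert-base-cased")
dec_tok = AutoTokenizer.from_pretrained("gpt2")
model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")

# Generation needs a decoder start token; reusing GPT-2's bos_token_id here is an
# assumption for the sketch, not something this commit configures.
model.config.decoder_start_token_id = dec_tok.bos_token_id
model.config.pad_token_id = enc_tok.pad_token_id

inputs = enc_tok("A short article to summarize.", return_tensors="tf")
output_ids = model.generate(input_ids=inputs["input_ids"], max_length=32)
print(dec_tok.batch_decode(output_ids.numpy(), skip_special_tokens=True))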
...@@ -597,7 +597,6 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
)
return outputs
# Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2Model.serving_output
def serving_output(self, output):
pkv = tf.convert_to_tensor(output.past_key_values) if self.config.use_cache else None
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
...@@ -754,7 +753,6 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
attentions=transformer_outputs.attentions,
)
# Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.serving_output
def serving_output(self, output):
pkv = tf.convert_to_tensor(output.past_key_values) if self.config.use_cache else None
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
......
...@@ -22,6 +22,7 @@ import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import (
DUMMY_INPUTS,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
...@@ -29,8 +30,8 @@ from ...file_utils import (
replace_return_docstrings,
)
from ...modeling_tf_outputs import (
TFBaseModelOutputWithPastAndCrossAttentions,
TFCausalLMOutputWithCrossAttentions,
TFSequenceClassifierOutputWithPast,
)
from ...modeling_tf_utils import (
...@@ -66,7 +67,7 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
class TFAttention(tf.keras.layers.Layer):
def __init__(self, nx, config, scale=False, is_cross_attention=False, **kwargs):
super().__init__(**kwargs)
n_state = nx  # in Attention: n_state=768 (nx=n_embd)
...@@ -77,7 +78,14 @@ class TFAttention(tf.keras.layers.Layer):
self.scale = scale
self.output_attentions = config.output_attentions
self.is_cross_attention = is_cross_attention
if self.is_cross_attention:
self.c_attn = TFConv1D(n_state * 2, nx, initializer_range=config.initializer_range, name="c_attn")
self.q_attn = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="q_attn")
else:
self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
...@@ -104,6 +112,9 @@ class TFAttention(tf.keras.layers.Layer):
dk = tf.cast(shape_list(k)[-1], dtype=w.dtype)  # scale attention_scores
w = w / tf.math.sqrt(dk)
if not self.is_cross_attention:
# only the "normal" (non-cross) attention layer implements the causal mask
# w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
_, _, nd, ns = shape_list(w)
b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
...@@ -139,9 +150,34 @@ class TFAttention(tf.keras.layers.Layer):
x = tf.reshape(x, new_x_shape)
return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)
def call(
self,
x,
layer_past,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
training=False,
):
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
)
query = self.q_attn(x)
kv_out = self.c_attn(encoder_hidden_states)
key, value = tf.split(kv_out, 2, axis=2)
attention_mask = encoder_attention_mask
else:
x = self.c_attn(x)
query, key, value = tf.split(x, 3, axis=2)
query = self.split_heads(query)
key = self.split_heads(key)
value = self.split_heads(value)
...@@ -191,22 +227,75 @@ class TFBlock(tf.keras.layers.Layer):
self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
self.attn = TFAttention(nx, config, scale, name="attn")
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
if config.add_cross_attention:
self.crossattention = TFAttention(nx, config, scale, name="crossattention", is_cross_attention=True)
self.ln_cross_attn = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_epsilon, name="ln_cross_attn"
)
self.mlp = TFMLP(inner_dim, config, name="mlp")
def call(
self,
x,
layer_past,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
training=False,
):
a = self.ln_1(x)
output_attn = self.attn(
a,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
)
a = output_attn[0]  # output_attn: a, present, (attentions)
outputs = output_attn[1:]
x = x + a
# Cross-Attention Block
if encoder_hidden_states is not None:
# add one self-attention block for cross-attention
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
"cross-attention layers by setting `config.add_cross_attention=True`"
)
ca = self.ln_cross_attn(x)
output_cross_attn = self.crossattention(
ca,
layer_past=None,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=False,
output_attentions=output_attentions,
training=training,
)
ca = output_cross_attn[0] # output_attn: a, present, (cross_attentions)
x = x + ca
outputs = outputs + output_cross_attn[2:] # add cross attentions if we output attention weights
m = self.ln_2(x)
m = self.mlp(m, training=training)
x = x + m
outputs = [x] + outputs
return outputs  # x, present, (attentions, cross_attentions)
@keras_serializable
...@@ -267,6 +356,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
...@@ -284,6 +375,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
...@@ -333,6 +426,31 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
)
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000
if self.config.add_cross_attention and inputs["encoder_attention_mask"] is not None:
# If a 2D or 3D attention mask is provided for the cross-attention,
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
inputs["encoder_attention_mask"] = tf.cast(
inputs["encoder_attention_mask"], dtype=inputs["encoder_hidden_states"].dtype
)
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
if num_dims_encoder_attention_mask == 3:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
if num_dims_encoder_attention_mask == 2:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
else:
encoder_extended_attention_mask = None
inputs["encoder_attention_mask"] = encoder_extended_attention_mask
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
...@@ -368,6 +486,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
presents = () if inputs["use_cache"] else None
all_attentions = () if inputs["output_attentions"] else None
all_cross_attentions = () if inputs["output_attentions"] and self.config.add_cross_attention else None
all_hidden_states = () if inputs["output_hidden_states"] else None
for i, (block, layer_past) in enumerate(zip(self.h, inputs["past"])):
if inputs["output_hidden_states"]:
...@@ -378,6 +497,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
layer_past,
inputs["attention_mask"],
inputs["head_mask"][i],
inputs["encoder_hidden_states"],
inputs["encoder_attention_mask"],
inputs["use_cache"], inputs["use_cache"],
inputs["output_attentions"], inputs["output_attentions"],
training=inputs["training"], training=inputs["training"],
...@@ -389,6 +510,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -389,6 +510,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_attentions = all_attentions + (outputs[2],) all_attentions = all_attentions + (outputs[2],)
if self.config.add_cross_attention and encoder_hidden_states is not None:
all_cross_attentions = all_cross_attentions + (outputs[3],)
hidden_states = self.ln_f(hidden_states)
...@@ -403,13 +526,18 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
if not inputs["return_dict"]:
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_attentions, all_cross_attentions]
if v is not None
)
return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
...@@ -422,7 +550,25 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
config_class = GPT2Config
base_model_prefix = "transformer"
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"h.\d+.attn.bias", r"h.\d+.crossattention.bias"]
@property
def dummy_inputs(self):
"""
Dummy inputs to build the network.
Returns:
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
"""
dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
if self.config.add_cross_attention:
batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
shape = (batch_size, seq_len) + (self.config.hidden_size,)
h = tf.random.uniform(shape=shape)
dummy["encoder_hidden_states"] = h
return dummy
@tf.function(
input_signature=[
...@@ -588,7 +734,7 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def call(
...@@ -600,6 +746,8 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
...@@ -607,6 +755,26 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
training=False,
**kwargs,
):
r"""
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that
don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all
:obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past`). Set to :obj:`False` during training, :obj:`True` during generation.
"""
inputs = input_processing(
func=self.call,
config=self.config,
...@@ -617,6 +785,8 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
...@@ -632,6 +802,8 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
...@@ -645,9 +817,20 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
pkv = tf.convert_to_tensor(output.past_key_values) if self.config.use_cache else None
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
cross_attns = (
tf.convert_to_tensor(output.cross_attentions)
if self.config.output_attentions
and self.config.add_cross_attention
and output.cross_attentions is not None
else None
)
return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=output.last_hidden_state,
past_key_values=pkv,
hidden_states=hs,
attentions=attns,
cross_attentions=cross_attns,
)
...@@ -680,7 +863,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFCausalLMOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def call(
...@@ -692,6 +875,8 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
...@@ -701,6 +886,24 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
**kwargs,
):
r"""
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that
don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all
:obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past`). Set to :obj:`False` during training, :obj:`True` during generation.
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
config.vocab_size - 1]``.
...@@ -715,6 +918,8 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
...@@ -731,6 +936,8 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
...@@ -751,20 +958,30 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return TFCausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
def serving_output(self, output):
pkv = tf.convert_to_tensor(output.past_key_values) if self.config.use_cache else None
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
cross_attns = (
tf.convert_to_tensor(output.cross_attentions)
if self.config.output_attentions
and self.config.add_cross_attention
and output.cross_attentions is not None
else None
)
return TFCausalLMOutputWithCrossAttentions(
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
)
@add_start_docstrings(
...@@ -871,16 +1088,18 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
transformer_outputs = self.transformer(
input_ids=flat_input_ids,
past=inputs["past"],
attention_mask=flat_attention_mask,
token_type_ids=flat_token_type_ids,
position_ids=flat_position_ids,
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
......
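A side note on the encoder mask handling added above: the 2D padding mask is cast to the hidden-state dtype, expanded to [batch_size, 1, 1, src_len], and converted into an additive bias (0 for visible tokens, -10000 for padding) that is added to the attention scores before the softmax. A standalone illustration of that transformation, not part of the diff:

import tensorflow as tf

encoder_attention_mask = tf.constant([[1.0, 1.0, 1.0, 0.0]])  # (batch, src_len), 1 = keep, 0 = pad
extended = encoder_attention_mask[:, None, None, :]           # broadcastable to (batch, heads, tgt_len, src_len)
extended = (1.0 - extended) * -10000.0                        # 0.0 keeps a token, -10000.0 masks it out
print(extended.numpy())  # [[[[0. 0. 0. -10000.]]]]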
...@@ -25,6 +25,7 @@ from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_
from .test_modeling_tf_bert import TFBertModelTester
from .test_modeling_tf_common import ids_tensor
from .test_modeling_tf_gpt2 import TFGPT2ModelTester
from .test_modeling_tf_rembert import TFRemBertModelTester
from .test_modeling_tf_roberta import TFRobertaModelTester
...@@ -39,6 +40,7 @@ if is_tf_available():
TFBertLMHeadModel,
TFBertModel,
TFEncoderDecoderModel,
TFGPT2LMHeadModel,
TFRemBertForCausalLM,
TFRemBertModel,
TFRobertaForCausalLM,
...@@ -432,7 +434,7 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
"""Not working, because pt checkpoint has `encoder.encoder.layer...` while tf model has `encoder.bert.encoder.layer...`.
(For Bert decoder, there is no issue, because `BertModel` is wrapped into `decoder` as `bert`)
model = TFEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16", from_pt=True)
"""
...@@ -456,6 +458,95 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
@require_tf
class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = TFBertModel(config, name="encoder")
decoder_model = TFGPT2LMHeadModel(decoder_config, name="decoder")
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = TFBertModelTester(self, batch_size=13)
model_tester_decoder = TFGPT2ModelTester(self)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
token_type_ids,
attention_mask,
sequence_labels,
token_labels,
choice_labels,
) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_attention_mask,
decoder_head_mask,
decoder_token_type_ids,
decoder_sequence_labels,
decoder_token_labels,
decoder_choice_labels,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_token_type_ids": decoder_token_type_ids,
"decoder_attention_mask": decoder_attention_mask,
"decoder_sequence_labels": decoder_sequence_labels,
"decoder_token_labels": decoder_token_labels,
"decoder_choice_labels": decoder_choice_labels,
"encoder_hidden_states": encoder_hidden_states,
"labels": decoder_token_labels,
}
@slow
@is_pt_tf_cross_test
def test_bert2gpt2_summarization(self):
from transformers import EncoderDecoderModel
tokenizer_in = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer_out = AutoTokenizer.from_pretrained("gpt2")
"""Not working, because pt checkpoint has `encoder.encoder.layer...` while tf model has `encoder.bert.encoder.layer...`.
(For GPT2 decoder, there is no issue)
model = TFEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16", from_pt=True)
"""
# workaround to load from pt
_model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
_model.encoder.save_pretrained("./encoder")
_model.decoder.save_pretrained("./decoder")
model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
)
model.config = _model.config
ARTICLE_STUDENTS = """(CNN)Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members singing a racist chant. SAE's national chapter suspended the students, but University of Oklahoma President David Boren took it a step further, saying the university's affiliation with the fraternity is permanently done. The news is shocking, but it's not the first time SAE has faced controversy. SAE was founded March 9, 1856, at the University of Alabama, five years before the American Civil War, according to the fraternity website. When the war began, the group had fewer than 400 members, of which "369 went to war for the Confederate States and seven for the Union Army," the website says. The fraternity now boasts more than 200,000 living alumni, along with about 15,000 undergraduates populating 219 chapters and 20 "colonies" seeking full membership at universities. SAE has had to work hard to change recently after a string of member deaths, many blamed on the hazing of new recruits, SAE national President Bradley Cohen wrote in a message on the fraternity's website. The fraternity's website lists more than 130 chapters cited or suspended for "health and safety incidents" since 2010. At least 30 of the incidents involved hazing, and dozens more involved alcohol. However, the list is missing numerous incidents from recent months. Among them, according to various media outlets: Yale University banned the SAEs from campus activities last month after members allegedly tried to interfere with a sexual misconduct investigation connected to an initiation rite. Stanford University in December suspended SAE housing privileges after finding sorority members attending a fraternity function were subjected to graphic sexual content. And Johns Hopkins University in November suspended the fraternity for underage drinking. "The media has labeled us as the 'nation's deadliest fraternity,' " Cohen said. In 2011, for example, a student died while being coerced into excessive alcohol consumption, according to a lawsuit. SAE's previous insurer dumped the fraternity. "As a result, we are paying Lloyd's of London the highest insurance rates in the Greek-letter world," Cohen said. Universities have turned down SAE's attempts to open new chapters, and the fraternity had to close 12 in 18 months over hazing incidents."""
EXPECTED_SUMMARY_STUDENTS = """SAS Alpha Epsilon suspended the students, but university president says it's permanent.\nThe fraternity has had to deal with a string of student deaths since 2010.\nSAS has more than 200,000 members, many of whom are students.\nA student died while being forced into excessive alcohol consumption."""
input_dict = tokenizer_in(ARTICLE_STUDENTS, return_tensors="tf")
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
summary = tokenizer_out.batch_decode(output_ids, skip_special_tokens=True)
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
@require_tf
class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
......
...@@ -19,7 +19,7 @@ from transformers import GPT2Config, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
...@@ -122,6 +122,35 @@ class TFGPT2ModelTester:
choice_labels,
)
def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
input_mask,
head_mask,
token_type_ids,
mc_token_ids,
sequence_labels,
token_labels,
choice_labels,
) = self.prepare_config_and_inputs()
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
input_mask,
head_mask,
token_type_ids,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
)
def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = TFGPT2Model(config=config)
inputs = {
......