Unverified Commit 6e603cb7 authored by Bharat Raghunathan, committed by GitHub
Browse files

[All models] Extend config.output_attentions with output_attentions function arguments (#4538)



* DOC: Replace instances of ``config.output_attentions`` with function argument ``output_attentions``

* DOC: Apply Black Formatting

* Fix errors where output_attentions was undefined

* Remove output_attentions in classes per review

* Fix regressions on tests having `output_attention`

* Fix further regressions in tests relating to `output_attentions`

Ensure proper propagation of `output_attentions` as a function parameter
to all model subclasses

* Fix more regressions in `test_output_attentions`

* Fix issues with BertEncoder

* Rename related variables to `output_attentions`

* fix pytorch tests

* fix bert and gpt2 tf

* Fix most TF tests for `test_output_attentions`

* Fix linter errors and more TF tests

* fix conflicts

* DOC: Apply Black Formatting

* Fix errors where output_attentions was undefined

* Remove output_attentions in classes per review

* Fix regressions on tests having `output_attention`

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* fix pytorch tests

* fix conflicts

* fix conflicts

* Fix linter errors and more TF tests

* fix tf tests

* make style

* fix isort

* improve output_attentions

* improve tensorflow
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent f90bc44d
...@@ -134,6 +134,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer): ...@@ -134,6 +134,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
training=False, training=False,
output_attentions=False,
): ):
# removed: src_enc=None, src_len=None # removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
...@@ -255,7 +256,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer): ...@@ -255,7 +256,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
if not self.pre_norm: if not self.pre_norm:
attn_outputs = self.attentions[i]([tensor, attn_mask, None, cache, head_mask[i]], training=training) attn_outputs = self.attentions[i]([tensor, attn_mask, None, cache, head_mask[i]], training=training)
attn = attn_outputs[0] attn = attn_outputs[0]
if self.output_attentions: if output_attentions:
attentions = attentions + (attn_outputs[1],) attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training) attn = self.dropout(attn, training=training)
tensor = tensor + attn tensor = tensor + attn
...@@ -266,7 +267,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer): ...@@ -266,7 +267,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
[tensor_normalized, attn_mask, None, cache, head_mask[i]], training=training [tensor_normalized, attn_mask, None, cache, head_mask[i]], training=training
) )
attn = attn_outputs[0] attn = attn_outputs[0]
if self.output_attentions: if output_attentions:
attentions = attentions + (attn_outputs[1],) attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training) attn = self.dropout(attn, training=training)
tensor = tensor + attn tensor = tensor + attn
...@@ -302,7 +303,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer): ...@@ -302,7 +303,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
outputs = (tensor,) outputs = (tensor,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (hidden_states,) outputs = outputs + (hidden_states,)
if self.output_attentions: if output_attentions:
outputs = outputs + (attentions,) outputs = outputs + (attentions,)
return outputs # outputs, (hidden_states), (attentions) return outputs # outputs, (hidden_states), (attentions)
......
...@@ -28,6 +28,7 @@ from .modeling_tf_utils import ( ...@@ -28,6 +28,7 @@ from .modeling_tf_utils import (
TFPreTrainedModel, TFPreTrainedModel,
TFSequenceSummary, TFSequenceSummary,
TFSharedEmbeddings, TFSharedEmbeddings,
cast_bool_to_primitive,
get_initializer, get_initializer,
keras_serializable, keras_serializable,
shape_list, shape_list,
...@@ -63,7 +64,6 @@ def gelu(x): ...@@ -63,7 +64,6 @@ def gelu(x):
class TFAttention(tf.keras.layers.Layer): class TFAttention(tf.keras.layers.Layer):
def __init__(self, nx, n_ctx, config, scale=False, **kwargs): def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_attentions = config.output_attentions
n_state = nx # in Attention: n_state=768 (nx=n_embd) n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem] # [switch nx => n_state from Block to Attention to keep identical to TF implem]
...@@ -93,7 +93,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -93,7 +93,7 @@ class TFAttention(tf.keras.layers.Layer):
return tf.cast(m, dtype) return tf.cast(m, dtype)
def _attn(self, inputs, training=False): def _attn(self, inputs, training=False):
q, k, v, attention_mask, head_mask = inputs q, k, v, attention_mask, head_mask, output_attentions = inputs
# q, k, v have shape [batch, heads, sequence, features] # q, k, v have shape [batch, heads, sequence, features]
w = tf.matmul(q, k, transpose_b=True) w = tf.matmul(q, k, transpose_b=True)
if self.scale: if self.scale:
...@@ -118,7 +118,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -118,7 +118,7 @@ class TFAttention(tf.keras.layers.Layer):
w = w * head_mask w = w * head_mask
outputs = [tf.matmul(w, v)] outputs = [tf.matmul(w, v)]
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
outputs.append(w) outputs.append(w)
return outputs return outputs
...@@ -135,7 +135,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -135,7 +135,7 @@ class TFAttention(tf.keras.layers.Layer):
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
def call(self, inputs, training=False): def call(self, inputs, training=False):
x, layer_past, attention_mask, head_mask, use_cache = inputs x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
x = self.c_attn(x) x = self.c_attn(x)
query, key, value = tf.split(x, 3, axis=2) query, key, value = tf.split(x, 3, axis=2)
...@@ -148,20 +148,12 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -148,20 +148,12 @@ class TFAttention(tf.keras.layers.Layer):
value = tf.concat([past_value, value], axis=-2) value = tf.concat([past_value, value], axis=-2)
# to cope with keras serialization # to cope with keras serialization
# we need to cast `use_cache` to correct bool if cast_bool_to_primitive(use_cache, True) is True:
# if it is a tensor
if tf.is_tensor(use_cache):
if hasattr(use_cache, "numpy"):
use_cache = bool(use_cache.numpy())
else:
use_cache = True
if use_cache is True:
present = tf.stack([key, value], axis=0) present = tf.stack([key, value], axis=0)
else: else:
present = (None,) present = (None,)
attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training) attn_outputs = self._attn([query, key, value, attention_mask, head_mask, output_attentions], training=training)
a = attn_outputs[0] a = attn_outputs[0]
a = self.merge_heads(a) a = self.merge_heads(a)
...@@ -198,10 +190,12 @@ class TFBlock(tf.keras.layers.Layer): ...@@ -198,10 +190,12 @@ class TFBlock(tf.keras.layers.Layer):
self.mlp = TFMLP(4 * nx, config, name="mlp") self.mlp = TFMLP(4 * nx, config, name="mlp")
def call(self, inputs, training=False): def call(self, inputs, training=False):
x, layer_past, attention_mask, head_mask, use_cache = inputs x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
a = self.ln_1(x) a = self.ln_1(x)
output_attn = self.attn([a, layer_past, attention_mask, head_mask, use_cache], training=training) output_attn = self.attn(
[a, layer_past, attention_mask, head_mask, use_cache, output_attentions], training=training
)
a = output_attn[0] # output_attn: a, present, (attentions) a = output_attn[0] # output_attn: a, present, (attentions)
x = x + a x = x + a
...@@ -219,8 +213,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -219,8 +213,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs) super().__init__(*inputs, **kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.num_hidden_layers = config.n_layer self.num_hidden_layers = config.n_layer
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.n_embd = config.n_embd self.n_embd = config.n_embd
...@@ -261,6 +255,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -261,6 +255,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
inputs_embeds=None, inputs_embeds=None,
use_cache=True, use_cache=True,
training=False, training=False,
output_attentions=None,
): ):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
input_ids = inputs[0] input_ids = inputs[0]
...@@ -271,7 +266,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -271,7 +266,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
head_mask = inputs[5] if len(inputs) > 5 else head_mask head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
use_cache = inputs[7] if len(inputs) > 7 else use_cache use_cache = inputs[7] if len(inputs) > 7 else use_cache
assert len(inputs) <= 8, "Too many inputs." output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
past = inputs.get("past", past) past = inputs.get("past", past)
...@@ -281,10 +277,13 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -281,10 +277,13 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache) use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 8, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 9, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -355,12 +354,15 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -355,12 +354,15 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block([hidden_states, layer_past, attention_mask, head_mask[i], use_cache], training=training) outputs = block(
[hidden_states, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
training=training,
)
hidden_states, present = outputs[:2] hidden_states, present = outputs[:2]
presents = presents + (present,) presents = presents + (present,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
all_attentions.append(outputs[2]) all_attentions.append(outputs[2])
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
...@@ -376,7 +378,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -376,7 +378,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
outputs = outputs + (presents,) outputs = outputs + (presents,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
# let the number of heads free (-1) so we can extract attention even after head pruning # let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
...@@ -493,7 +495,7 @@ class TFGPT2Model(TFGPT2PreTrainedModel): ...@@ -493,7 +495,7 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -552,7 +554,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel): ...@@ -552,7 +554,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -614,6 +616,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -614,6 +616,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
inputs_embeds=None, inputs_embeds=None,
mc_token_ids=None, mc_token_ids=None,
use_cache=True, use_cache=True,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -636,7 +639,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -636,7 +639,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -681,7 +684,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -681,7 +684,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
use_cache = inputs[8] if len(inputs) > 8 else use_cache use_cache = inputs[8] if len(inputs) > 8 else use_cache
assert len(inputs) <= 9, "Too many inputs." output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, dict): elif isinstance(inputs, dict):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
past = inputs.get("past", past) past = inputs.get("past", past)
...@@ -692,7 +696,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -692,7 +696,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
mc_token_ids = inputs.get("mc_token_ids", mc_token_ids) mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
use_cache = inputs.get("use_cache", use_cache) use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 9, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 10, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
...@@ -717,6 +722,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -717,6 +722,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
head_mask, head_mask,
inputs_embeds, inputs_embeds,
use_cache, use_cache,
output_attentions,
] ]
transformer_outputs = self.transformer(flat_inputs, training=training) transformer_outputs = self.transformer(flat_inputs, training=training)
......
...@@ -28,6 +28,7 @@ from .modeling_tf_utils import ( ...@@ -28,6 +28,7 @@ from .modeling_tf_utils import (
TFPreTrainedModel, TFPreTrainedModel,
TFSequenceSummary, TFSequenceSummary,
TFSharedEmbeddings, TFSharedEmbeddings,
cast_bool_to_primitive,
get_initializer, get_initializer,
keras_serializable, keras_serializable,
shape_list, shape_list,
...@@ -70,7 +71,6 @@ ACT_FNS = { ...@@ -70,7 +71,6 @@ ACT_FNS = {
class TFAttention(tf.keras.layers.Layer): class TFAttention(tf.keras.layers.Layer):
def __init__(self, nx, n_ctx, config, scale=False, **kwargs): def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_attentions = config.output_attentions
n_state = nx # in Attention: n_state=768 (nx=n_embd) n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem] # [switch nx => n_state from Block to Attention to keep identical to TF implem]
...@@ -100,7 +100,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -100,7 +100,7 @@ class TFAttention(tf.keras.layers.Layer):
return tf.cast(m, dtype) return tf.cast(m, dtype)
def _attn(self, inputs, training=False): def _attn(self, inputs, training=False):
q, k, v, attention_mask, head_mask = inputs q, k, v, attention_mask, head_mask, output_attentions = inputs
# q, k, v have shape [batch, heads, sequence, features] # q, k, v have shape [batch, heads, sequence, features]
w = tf.matmul(q, k, transpose_b=True) w = tf.matmul(q, k, transpose_b=True)
if self.scale: if self.scale:
...@@ -125,7 +125,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -125,7 +125,7 @@ class TFAttention(tf.keras.layers.Layer):
w = w * head_mask w = w * head_mask
outputs = [tf.matmul(w, v)] outputs = [tf.matmul(w, v)]
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
outputs.append(w) outputs.append(w)
return outputs return outputs
...@@ -142,7 +142,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -142,7 +142,7 @@ class TFAttention(tf.keras.layers.Layer):
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
def call(self, inputs, training=False): def call(self, inputs, training=False):
x, attention_mask, head_mask = inputs x, attention_mask, head_mask, output_attentions = inputs
x = self.c_attn(x) x = self.c_attn(x)
query, key, value = tf.split(x, 3, axis=2) query, key, value = tf.split(x, 3, axis=2)
...@@ -150,7 +150,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -150,7 +150,7 @@ class TFAttention(tf.keras.layers.Layer):
key = self.split_heads(key) key = self.split_heads(key)
value = self.split_heads(value) value = self.split_heads(value)
attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training) attn_outputs = self._attn([query, key, value, attention_mask, head_mask, output_attentions], training=training)
a = attn_outputs[0] a = attn_outputs[0]
a = self.merge_heads(a) a = self.merge_heads(a)
...@@ -187,9 +187,9 @@ class TFBlock(tf.keras.layers.Layer): ...@@ -187,9 +187,9 @@ class TFBlock(tf.keras.layers.Layer):
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2") self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
def call(self, inputs, training=False): def call(self, inputs, training=False):
x, attention_mask, head_mask = inputs x, attention_mask, head_mask, output_attentions = inputs
output_attn = self.attn([x, attention_mask, head_mask], training=training) output_attn = self.attn([x, attention_mask, head_mask, output_attentions], training=training)
a = output_attn[0] # output_attn: a, (attentions) a = output_attn[0] # output_attn: a, (attentions)
n = self.ln_1(x + a) n = self.ln_1(x + a)
...@@ -244,6 +244,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): ...@@ -244,6 +244,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None,
training=False, training=False,
): ):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
...@@ -253,7 +254,8 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): ...@@ -253,7 +254,8 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
position_ids = inputs[3] if len(inputs) > 3 else position_ids position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs." output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
...@@ -261,10 +263,13 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): ...@@ -261,10 +263,13 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
position_ids = inputs.get("position_ids", position_ids) position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 6, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -329,9 +334,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): ...@@ -329,9 +334,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block([hidden_states, attention_mask, head_mask[i]], training=training) outputs = block([hidden_states, attention_mask, head_mask[i], output_attentions], training=training)
hidden_states = outputs[0] hidden_states = outputs[0]
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
all_attentions.append(outputs[1]) all_attentions.append(outputs[1])
hidden_states = tf.reshape(hidden_states, output_shape) hidden_states = tf.reshape(hidden_states, output_shape)
...@@ -342,7 +347,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): ...@@ -342,7 +347,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
outputs = (hidden_states,) outputs = (hidden_states,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
# let the number of heads free (-1) so we can extract attention even after head pruning # let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
...@@ -448,7 +453,7 @@ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel): ...@@ -448,7 +453,7 @@ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -496,7 +501,7 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel): ...@@ -496,7 +501,7 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -555,6 +560,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): ...@@ -555,6 +560,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
mc_token_ids=None, mc_token_ids=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -577,7 +583,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): ...@@ -577,7 +583,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -617,7 +623,8 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): ...@@ -617,7 +623,8 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
head_mask = inputs[4] if len(inputs) > 4 else head_mask head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
assert len(inputs) <= 7, "Too many inputs." output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, dict): elif isinstance(inputs, dict):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
...@@ -626,7 +633,8 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): ...@@ -626,7 +633,8 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
mc_token_ids = inputs.get("mc_token_ids", mc_token_ids) mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
assert len(inputs) <= 7, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 8, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
...@@ -649,6 +657,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): ...@@ -649,6 +657,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
flat_position_ids, flat_position_ids,
head_mask, head_mask,
inputs_embeds, inputs_embeds,
output_attentions,
] ]
transformer_outputs = self.transformer(flat_inputs, training=training) transformer_outputs = self.transformer(flat_inputs, training=training)
......
...@@ -213,7 +213,7 @@ class TFRobertaModel(TFRobertaPreTrainedModel): ...@@ -213,7 +213,7 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -289,7 +289,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel): ...@@ -289,7 +289,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -365,6 +365,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla ...@@ -365,6 +365,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -377,7 +378,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla ...@@ -377,7 +378,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -403,6 +404,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla ...@@ -403,6 +404,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
...@@ -452,6 +454,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss) ...@@ -452,6 +454,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -471,7 +474,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss) ...@@ -471,7 +474,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -576,6 +579,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific ...@@ -576,6 +579,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -592,7 +596,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific ...@@ -592,7 +596,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -618,6 +622,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific ...@@ -618,6 +622,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
...@@ -663,6 +668,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin ...@@ -663,6 +668,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
cls_index=None, cls_index=None,
p_mask=None, p_mask=None,
is_impossible=None, is_impossible=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -686,7 +692,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin ...@@ -686,7 +692,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -717,6 +723,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin ...@@ -717,6 +723,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
......
...@@ -25,7 +25,13 @@ import tensorflow as tf ...@@ -25,7 +25,13 @@ import tensorflow as tf
from .configuration_t5 import T5Config from .configuration_t5 import T5Config
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list from .modeling_tf_utils import (
TFPreTrainedModel,
TFSharedEmbeddings,
cast_bool_to_primitive,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
...@@ -105,7 +111,6 @@ class TFT5Attention(tf.keras.layers.Layer): ...@@ -105,7 +111,6 @@ class TFT5Attention(tf.keras.layers.Layer):
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias self.has_relative_attention_bias = has_relative_attention_bias
self.output_attentions = config.output_attentions
self.relative_attention_num_buckets = config.relative_attention_num_buckets self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.d_model = config.d_model self.d_model = config.d_model
self.d_kv = config.d_kv self.d_kv = config.d_kv
...@@ -198,6 +203,7 @@ class TFT5Attention(tf.keras.layers.Layer): ...@@ -198,6 +203,7 @@ class TFT5Attention(tf.keras.layers.Layer):
query_length=None, query_length=None,
use_cache=False, use_cache=False,
training=False, training=False,
output_attentions=False,
): ):
""" """
Self-attention (if kv is None) or attention over source sentence (provided by kv). Self-attention (if kv is None) or attention over source sentence (provided by kv).
...@@ -250,13 +256,7 @@ class TFT5Attention(tf.keras.layers.Layer): ...@@ -250,13 +256,7 @@ class TFT5Attention(tf.keras.layers.Layer):
k, v = past_key_value_state k, v = past_key_value_state
# to cope with keras serialization # to cope with keras serialization
# we need to cast `use_cache` to correct bool use_cache = cast_bool_to_primitive(use_cache)
# if it is a tensor
if tf.is_tensor(use_cache):
if hasattr(use_cache, "numpy"):
use_cache = bool(use_cache.numpy())
else:
use_cache = True
if self.is_decoder and use_cache is True: if self.is_decoder and use_cache is True:
present_key_value_state = ((k, v),) present_key_value_state = ((k, v),)
...@@ -293,7 +293,7 @@ class TFT5Attention(tf.keras.layers.Layer): ...@@ -293,7 +293,7 @@ class TFT5Attention(tf.keras.layers.Layer):
outputs = (context,) + present_key_value_state outputs = (context,) + present_key_value_state
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (weights,) outputs = outputs + (weights,)
if self.has_relative_attention_bias: if self.has_relative_attention_bias:
outputs = outputs + (position_bias,) outputs = outputs + (position_bias,)
...@@ -317,6 +317,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer): ...@@ -317,6 +317,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
head_mask=None, head_mask=None,
past_key_value_state=None, past_key_value_state=None,
use_cache=False, use_cache=False,
output_attentions=False,
training=False, training=False,
): ):
norm_x = self.layer_norm(hidden_states) norm_x = self.layer_norm(hidden_states)
...@@ -327,6 +328,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer): ...@@ -327,6 +328,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
head_mask=head_mask, head_mask=head_mask,
past_key_value_state=past_key_value_state, past_key_value_state=past_key_value_state,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
training=training, training=training,
) )
y = attention_output[0] y = attention_output[0]
...@@ -354,6 +356,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer): ...@@ -354,6 +356,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
past_key_value_state=None, past_key_value_state=None,
query_length=None, query_length=None,
use_cache=False, use_cache=False,
output_attentions=False,
training=False, training=False,
): ):
norm_x = self.layer_norm(hidden_states) norm_x = self.layer_norm(hidden_states)
...@@ -366,6 +369,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer): ...@@ -366,6 +369,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
past_key_value_state=past_key_value_state, past_key_value_state=past_key_value_state,
query_length=query_length, query_length=query_length,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
training=training, training=training,
) )
y = attention_output[0] y = attention_output[0]
...@@ -402,6 +406,7 @@ class TFT5Block(tf.keras.layers.Layer): ...@@ -402,6 +406,7 @@ class TFT5Block(tf.keras.layers.Layer):
head_mask=None, head_mask=None,
past_key_value_state=None, past_key_value_state=None,
use_cache=False, use_cache=False,
output_attentions=False,
training=False, training=False,
): ):
...@@ -428,6 +433,7 @@ class TFT5Block(tf.keras.layers.Layer): ...@@ -428,6 +433,7 @@ class TFT5Block(tf.keras.layers.Layer):
head_mask=head_mask, head_mask=head_mask,
past_key_value_state=self_attn_past_key_value_state, past_key_value_state=self_attn_past_key_value_state,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
training=training, training=training,
) )
hidden_states, present_key_value_state = self_attention_outputs[:2] hidden_states, present_key_value_state = self_attention_outputs[:2]
...@@ -450,6 +456,7 @@ class TFT5Block(tf.keras.layers.Layer): ...@@ -450,6 +456,7 @@ class TFT5Block(tf.keras.layers.Layer):
past_key_value_state=cross_attn_past_key_value_state, past_key_value_state=cross_attn_past_key_value_state,
query_length=query_length, query_length=query_length,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
training=training, training=training,
) )
hidden_states = cross_attention_outputs[0] hidden_states = cross_attention_outputs[0]
...@@ -509,8 +516,8 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -509,8 +516,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
def __init__(self, config, embed_tokens=None, **kwargs): def __init__(self, config, embed_tokens=None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
...@@ -550,6 +557,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -550,6 +557,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
head_mask=None, head_mask=None,
past_key_value_states=None, past_key_value_states=None,
use_cache=False, use_cache=False,
output_attentions=None,
training=False, training=False,
): ):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
...@@ -560,7 +568,8 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -560,7 +568,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
head_mask = inputs[5] if len(inputs) > 5 else head_mask head_mask = inputs[5] if len(inputs) > 5 else head_mask
past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states
assert len(inputs) <= 7, "Too many inputs." output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("decoder_input_ids") input_ids = inputs.get("decoder_input_ids")
attention_mask = inputs.get("decoder_attention_mask", attention_mask) attention_mask = inputs.get("decoder_attention_mask", attention_mask)
...@@ -569,10 +578,13 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -569,10 +578,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
past_key_value_states = inputs.get("past_key_value_states", past_key_value_states) past_key_value_states = inputs.get("past_key_value_states", past_key_value_states)
assert len(inputs) <= 7, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 8, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both inputs and inputs_embeds at the same time") raise ValueError("You cannot specify both inputs and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -697,6 +709,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -697,6 +709,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
head_mask=head_mask[i], head_mask=head_mask[i],
past_key_value_state=past_key_value_state, past_key_value_state=past_key_value_state,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
training=training, training=training,
) )
# layer_outputs is a tuple with: # layer_outputs is a tuple with:
...@@ -705,13 +718,13 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -705,13 +718,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
if i == 0: if i == 0:
# We share the position biases between the layers - the first layer store them # We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[3 if self.output_attentions else 2] position_bias = layer_outputs[3 if output_attentions else 2]
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[5 if self.output_attentions else 3] encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]
# append next layer key value states # append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,) present_key_value_states = present_key_value_states + (present_key_value_state,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
all_attentions = all_attentions + (layer_outputs[2],) all_attentions = all_attentions + (layer_outputs[2],)
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
...@@ -727,7 +740,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -727,7 +740,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
outputs = outputs + (present_key_value_states,) outputs = outputs + (present_key_value_states,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions) return outputs # last-layer hidden state, (all hidden states), (all attentions)
...@@ -896,7 +909,7 @@ class TFT5Model(TFT5PreTrainedModel): ...@@ -896,7 +909,7 @@ class TFT5Model(TFT5PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -931,11 +944,16 @@ class TFT5Model(TFT5PreTrainedModel): ...@@ -931,11 +944,16 @@ class TFT5Model(TFT5PreTrainedModel):
decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None) decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
use_cache = kwargs.get("use_cache", True) use_cache = kwargs.get("use_cache", True)
head_mask = kwargs.get("head_mask", None) head_mask = kwargs.get("head_mask", None)
output_attentions = kwargs.get("output_attentions", None)
# Encode if needed (training, first prediction pass) # Encode if needed (training, first prediction pass)
if encoder_outputs is None: if encoder_outputs is None:
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
inputs, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask, inputs,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
) )
hidden_states = encoder_outputs[0] hidden_states = encoder_outputs[0]
...@@ -958,6 +976,7 @@ class TFT5Model(TFT5PreTrainedModel): ...@@ -958,6 +976,7 @@ class TFT5Model(TFT5PreTrainedModel):
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
if use_cache is True: if use_cache is True:
...@@ -1018,7 +1037,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel): ...@@ -1018,7 +1037,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -1057,12 +1076,17 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel): ...@@ -1057,12 +1076,17 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
inputs_embeds = kwargs.get("inputs_embeds", None) inputs_embeds = kwargs.get("inputs_embeds", None)
decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None) decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
head_mask = kwargs.get("head_mask", None) head_mask = kwargs.get("head_mask", None)
output_attentions = kwargs.get("output_attentions", None)
# Encode if needed (training, first prediction pass) # Encode if needed (training, first prediction pass)
if encoder_outputs is None: if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed # Convert encoder inputs in embeddings if needed
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
inputs, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask, inputs,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
) )
hidden_states = encoder_outputs[0] hidden_states = encoder_outputs[0]
...@@ -1085,6 +1109,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel): ...@@ -1085,6 +1109,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
# insert decoder past at right place # insert decoder past at right place
......
...@@ -24,7 +24,13 @@ import tensorflow as tf ...@@ -24,7 +24,13 @@ import tensorflow as tf
from .configuration_transfo_xl import TransfoXLConfig from .configuration_transfo_xl import TransfoXLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list from .modeling_tf_utils import (
TFPreTrainedModel,
cast_bool_to_primitive,
get_initializer,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
...@@ -109,14 +115,12 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer): ...@@ -109,14 +115,12 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
pre_lnorm=False, pre_lnorm=False,
r_r_bias=None, r_r_bias=None,
r_w_bias=None, r_w_bias=None,
output_attentions=False,
layer_norm_epsilon=1e-5, layer_norm_epsilon=1e-5,
init_std=0.02, init_std=0.02,
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_attentions = output_attentions
self.n_head = n_head self.n_head = n_head
self.d_model = d_model self.d_model = d_model
self.d_head = d_head self.d_head = d_head
...@@ -170,7 +174,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer): ...@@ -170,7 +174,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
return x return x
def call(self, inputs, training=False): def call(self, inputs, training=False):
w, r, attn_mask, mems, head_mask = inputs w, r, attn_mask, mems, head_mask, output_attentions = inputs
qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1] qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
if mems is not None: if mems is not None:
...@@ -243,7 +247,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer): ...@@ -243,7 +247,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
# residual connection + layer normalization # residual connection + layer normalization
outputs = [self.layer_norm(w + attn_out)] outputs = [self.layer_norm(w + attn_out)]
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
outputs.append(attn_prob) outputs.append(attn_prob)
return outputs return outputs
...@@ -264,7 +268,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer): ...@@ -264,7 +268,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
pre_lnorm=False, pre_lnorm=False,
r_w_bias=None, r_w_bias=None,
r_r_bias=None, r_r_bias=None,
output_attentions=False,
layer_norm_epsilon=1e-5, layer_norm_epsilon=1e-5,
init_std=0.02, init_std=0.02,
**kwargs **kwargs
...@@ -284,7 +287,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer): ...@@ -284,7 +287,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
r_w_bias=r_w_bias, r_w_bias=r_w_bias,
r_r_bias=r_r_bias, r_r_bias=r_r_bias,
init_std=init_std, init_std=init_std,
output_attentions=output_attentions,
layer_norm_epsilon=layer_norm_epsilon, layer_norm_epsilon=layer_norm_epsilon,
name="dec_attn", name="dec_attn",
) )
...@@ -299,8 +301,10 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer): ...@@ -299,8 +301,10 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
) )
def call(self, inputs, training=False): def call(self, inputs, training=False):
dec_inp, r, dec_attn_mask, mems, head_mask = inputs dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions = inputs
attn_outputs = self.dec_attn([dec_inp, r, dec_attn_mask, mems, head_mask], training=training) attn_outputs = self.dec_attn(
[dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions], training=training
)
ff_output = self.pos_ff(attn_outputs[0], training=training) ff_output = self.pos_ff(attn_outputs[0], training=training)
outputs = [ff_output] + attn_outputs[1:] outputs = [ff_output] + attn_outputs[1:]
...@@ -386,8 +390,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -386,8 +390,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.n_token = config.vocab_size self.n_token = config.vocab_size
...@@ -435,7 +439,6 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -435,7 +439,6 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
pre_lnorm=config.pre_lnorm, pre_lnorm=config.pre_lnorm,
r_w_bias=None if self.untie_r else self.r_w_bias, r_w_bias=None if self.untie_r else self.r_w_bias,
r_r_bias=None if self.untie_r else self.r_r_bias, r_r_bias=None if self.untie_r else self.r_r_bias,
output_attentions=self.output_attentions,
layer_norm_epsilon=config.layer_norm_epsilon, layer_norm_epsilon=config.layer_norm_epsilon,
init_std=config.init_std, init_std=config.init_std,
name="layers_._{}".format(i), name="layers_._{}".format(i),
...@@ -514,22 +517,26 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -514,22 +517,26 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
return new_mems return new_mems
def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, training=False): def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, output_attentions=None, training=False):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
input_ids = inputs[0] input_ids = inputs[0]
mems = inputs[1] if len(inputs) > 1 else mems mems = inputs[1] if len(inputs) > 1 else mems
head_mask = inputs[2] if len(inputs) > 2 else head_mask head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
assert len(inputs) <= 4, "Too many inputs." output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
assert len(inputs) <= 5, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
mems = inputs.get("mems", mems) mems = inputs.get("mems", mems)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 4, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 5, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz] # so we transpose here from shape [bsz, len] to shape [len, bsz]
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
...@@ -600,9 +607,11 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -600,9 +607,11 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hids.append(core_out) hids.append(core_out)
mems_i = None if mems is None else mems[i] mems_i = None if mems is None else mems[i]
layer_outputs = layer([core_out, pos_emb, dec_attn_mask, mems_i, head_mask[i]], training=training) layer_outputs = layer(
[core_out, pos_emb, dec_attn_mask, mems_i, head_mask[i], output_attentions], training=training,
)
core_out = layer_outputs[0] core_out = layer_outputs[0]
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
attentions.append(layer_outputs[1]) attentions.append(layer_outputs[1])
else: # learnable embeddings and absolute embeddings else: # learnable embeddings and absolute embeddings
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
...@@ -618,7 +627,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -618,7 +627,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
hids.append(core_out) hids.append(core_out)
hids = list(tf.transpose(t, perm=(1, 0, 2)) for t in hids) hids = list(tf.transpose(t, perm=(1, 0, 2)) for t in hids)
outputs.append(hids) outputs.append(hids)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len] # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
attentions = list(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) attentions = list(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
outputs.append(attentions) outputs.append(attentions)
...@@ -711,7 +720,7 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel): ...@@ -711,7 +720,7 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -785,7 +794,16 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -785,7 +794,16 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
return self.transformer.init_mems(bsz) return self.transformer.init_mems(bsz)
@add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, labels=None, training=False): def call(
self,
inputs,
mems=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
training=False,
):
r""" r"""
Return: Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.TransfoXLConfig`) and inputs: :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.TransfoXLConfig`) and inputs:
...@@ -800,7 +818,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -800,7 +818,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -825,14 +843,16 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -825,14 +843,16 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
head_mask = inputs[2] if len(inputs) > 2 else head_mask head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
labels = inputs[4] if len(inputs) > 4 else labels labels = inputs[4] if len(inputs) > 4 else labels
assert len(inputs) <= 5, "Too many inputs." output_attentions = inputs[5] if len(inputs) > 5 else output_attentions
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, dict): elif isinstance(inputs, dict):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
mems = inputs.get("mems", mems) mems = inputs.get("mems", mems)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
labels = inputs.get("labels", labels) labels = inputs.get("labels", labels)
assert len(inputs) <= 5, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 6, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
...@@ -841,7 +861,9 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -841,7 +861,9 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
else: else:
bsz, tgt_len = shape_list(inputs_embeds)[:2] bsz, tgt_len = shape_list(inputs_embeds)[:2]
transformer_outputs = self.transformer([input_ids, mems, head_mask, inputs_embeds], training=training) transformer_outputs = self.transformer(
[input_ids, mems, head_mask, inputs_embeds, output_attentions], training=training
)
last_hidden = transformer_outputs[0] last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:] pred_hid = last_hidden[:, -tgt_len:]
......
...@@ -1755,3 +1755,24 @@ def get_initializer(initializer_range=0.02): ...@@ -1755,3 +1755,24 @@ def get_initializer(initializer_range=0.02):
TruncatedNormal initializer with stddev = `initializer_range`. TruncatedNormal initializer with stddev = `initializer_range`.
""" """
return tf.keras.initializers.TruncatedNormal(stddev=initializer_range) return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
def cast_bool_to_primitive(bool_variable, default_tensor_to_true=False):
    """Convert a possibly-tensor boolean flag back into a Python bool.

    Keras serialization can turn boolean call arguments (e.g.
    ``output_attentions``) into tensors; callers compare the result of this
    function with ``is True``, so an eager tensor must be cast back to a
    primitive bool when its concrete value is available.

    Args:
        bool_variable: a Python bool, or a ``tf.Tensor`` wrapping one.
        default_tensor_to_true: bool, value to assume when ``bool_variable``
            is a tensor without a ``numpy`` attribute (graph mode). If False,
            the tensor itself is returned unchanged.
    """
    if not tf.is_tensor(bool_variable):
        # Already a primitive bool — nothing to convert.
        return bool_variable
    if hasattr(bool_variable, "numpy"):
        # Eager tensor: extract the concrete boolean value.
        return bool(bool_variable.numpy())
    if default_tensor_to_true:
        # Graph-mode tensor with no concrete value: fall back to True.
        return True
    # Graph-mode tensor and no default requested: hand the tensor back
    # (an ``is True`` check by the caller will then evaluate to False).
    return bool_variable
...@@ -33,6 +33,7 @@ from .modeling_tf_utils import ( ...@@ -33,6 +33,7 @@ from .modeling_tf_utils import (
TFSequenceSummary, TFSequenceSummary,
TFSharedEmbeddings, TFSharedEmbeddings,
TFTokenClassificationLoss, TFTokenClassificationLoss,
cast_bool_to_primitive,
get_initializer, get_initializer,
keras_serializable, keras_serializable,
shape_list, shape_list,
...@@ -112,7 +113,6 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -112,7 +113,6 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, n_heads, dim, config, **kwargs): def __init__(self, n_heads, dim, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.layer_id = next(TFMultiHeadAttention.NEW_ID) self.layer_id = next(TFMultiHeadAttention.NEW_ID)
self.output_attentions = config.output_attentions
self.dim = dim self.dim = dim
self.n_heads = n_heads self.n_heads = n_heads
assert self.dim % self.n_heads == 0 assert self.dim % self.n_heads == 0
...@@ -131,7 +131,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -131,7 +131,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
""" """
Self-attention (if kv is None) or attention over source sentence (provided by kv). Self-attention (if kv is None) or attention over source sentence (provided by kv).
""" """
input, mask, kv, cache, head_mask = inputs input, mask, kv, cache, head_mask, output_attentions = inputs
# Input is (bs, qlen, dim) # Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen) # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
bs, qlen, dim = shape_list(input) bs, qlen, dim = shape_list(input)
...@@ -188,7 +188,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -188,7 +188,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
context = unshape(context) # (bs, qlen, dim) context = unshape(context) # (bs, qlen, dim)
outputs = (self.out_lin(context),) outputs = (self.out_lin(context),)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (weights,) outputs = outputs + (weights,)
return outputs return outputs
...@@ -215,8 +215,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -215,8 +215,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
# encoder / decoder, output layer # encoder / decoder, output layer
self.is_encoder = config.is_encoder self.is_encoder = config.is_encoder
...@@ -327,6 +327,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -327,6 +327,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
cache=None, cache=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None,
training=False, training=False,
): # removed: src_enc=None, src_len=None ): # removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
...@@ -339,7 +340,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -339,7 +340,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
cache = inputs[6] if len(inputs) > 6 else cache cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs." output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
...@@ -350,10 +352,13 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -350,10 +352,13 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
cache = inputs.get("cache", cache) cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 9, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 10, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -440,9 +445,11 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -440,9 +445,11 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
hidden_states = hidden_states + (tensor,) hidden_states = hidden_states + (tensor,)
# self attention # self attention
attn_outputs = self.attentions[i]([tensor, attn_mask, None, cache, head_mask[i]], training=training) attn_outputs = self.attentions[i](
[tensor, attn_mask, None, cache, head_mask[i], output_attentions], training=training
)
attn = attn_outputs[0] attn = attn_outputs[0]
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
attentions = attentions + (attn_outputs[1],) attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training) attn = self.dropout(attn, training=training)
tensor = tensor + attn tensor = tensor + attn
...@@ -474,7 +481,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -474,7 +481,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
outputs = (tensor,) outputs = (tensor,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (hidden_states,) outputs = outputs + (hidden_states,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (attentions,) outputs = outputs + (attentions,)
return outputs # outputs, (hidden_states), (attentions) return outputs # outputs, (hidden_states), (attentions)
...@@ -602,7 +609,7 @@ class TFXLMModel(TFXLMPreTrainedModel): ...@@ -602,7 +609,7 @@ class TFXLMModel(TFXLMPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -698,7 +705,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel): ...@@ -698,7 +705,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -752,6 +759,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat ...@@ -752,6 +759,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -770,7 +778,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat ...@@ -770,7 +778,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -800,6 +808,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat ...@@ -800,6 +808,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
cache=cache, cache=cache,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
output = transformer_outputs[0] output = transformer_outputs[0]
...@@ -849,6 +858,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -849,6 +858,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -868,7 +878,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -868,7 +878,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -900,7 +910,8 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -900,7 +910,8 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
cache = inputs[6] if len(inputs) > 6 else cache cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs." output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
...@@ -911,7 +922,8 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -911,7 +922,8 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
cache = inputs.get("cache", cache) cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 9, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 10, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
...@@ -937,6 +949,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -937,6 +949,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
cache, cache,
head_mask, head_mask,
inputs_embeds, inputs_embeds,
output_attentions,
] ]
transformer_outputs = self.transformer(flat_inputs, training=training) transformer_outputs = self.transformer(flat_inputs, training=training)
...@@ -982,6 +995,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos ...@@ -982,6 +995,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -998,7 +1012,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos ...@@ -998,7 +1012,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -1024,6 +1038,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos ...@@ -1024,6 +1038,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
...@@ -1071,6 +1086,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL ...@@ -1071,6 +1086,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
cls_index=None, cls_index=None,
p_mask=None, p_mask=None,
is_impossible=None, is_impossible=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -1094,7 +1110,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL ...@@ -1094,7 +1110,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -1127,6 +1143,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL ...@@ -1127,6 +1143,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
cache=cache, cache=cache,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
......
...@@ -32,6 +32,7 @@ from .modeling_tf_utils import ( ...@@ -32,6 +32,7 @@ from .modeling_tf_utils import (
TFSequenceSummary, TFSequenceSummary,
TFSharedEmbeddings, TFSharedEmbeddings,
TFTokenClassificationLoss, TFTokenClassificationLoss,
cast_bool_to_primitive,
get_initializer, get_initializer,
keras_serializable, keras_serializable,
shape_list, shape_list,
...@@ -71,7 +72,6 @@ ACT2FN = { ...@@ -71,7 +72,6 @@ ACT2FN = {
class TFXLNetRelativeAttention(tf.keras.layers.Layer): class TFXLNetRelativeAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_attentions = config.output_attentions
if config.d_model % config.n_head != 0: if config.d_model % config.n_head != 0:
raise ValueError( raise ValueError(
...@@ -137,7 +137,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -137,7 +137,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
def rel_attn_core(self, inputs, training=False): def rel_attn_core(self, inputs, training=False):
"""Core relative positional attention operations.""" """Core relative positional attention operations."""
q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask = inputs q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask, output_attentions = inputs
# content based attention score # content based attention score
ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h) ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h)
...@@ -174,7 +174,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -174,7 +174,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# attention output # attention output
attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h) attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
return attn_vec, attn_prob return attn_vec, attn_prob
return attn_vec return attn_vec
...@@ -195,7 +195,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -195,7 +195,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
return output return output
def call(self, inputs, training=False): def call(self, inputs, training=False):
(h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask) = inputs (h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask, output_attentions) = inputs
if g is not None: if g is not None:
# Two-stream attention with relative positional encoding. # Two-stream attention with relative positional encoding.
...@@ -220,10 +220,11 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -220,10 +220,11 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# core attention ops # core attention ops
attn_vec_h = self.rel_attn_core( attn_vec_h = self.rel_attn_core(
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask], training=training [q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions],
training=training,
) )
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
attn_vec_h, attn_prob_h = attn_vec_h attn_vec_h, attn_prob_h = attn_vec_h
# post processing # post processing
...@@ -237,25 +238,27 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -237,25 +238,27 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
if target_mapping is not None: if target_mapping is not None:
q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping) q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
attn_vec_g = self.rel_attn_core( attn_vec_g = self.rel_attn_core(
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask], training=training [q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions],
training=training,
) )
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
attn_vec_g, attn_prob_g = attn_vec_g attn_vec_g, attn_prob_g = attn_vec_g
attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping) attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
else: else:
attn_vec_g = self.rel_attn_core( attn_vec_g = self.rel_attn_core(
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask], training=training [q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions],
training=training,
) )
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
attn_vec_g, attn_prob_g = attn_vec_g attn_vec_g, attn_prob_g = attn_vec_g
# post processing # post processing
output_g = self.post_attention([g, attn_vec_g], training=training) output_g = self.post_attention([g, attn_vec_g], training=training)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
attn_prob = attn_prob_h, attn_prob_g attn_prob = attn_prob_h, attn_prob_g
else: else:
...@@ -275,10 +278,11 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -275,10 +278,11 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# core attention ops # core attention ops
attn_vec = self.rel_attn_core( attn_vec = self.rel_attn_core(
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask], training=training [q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions],
training=training,
) )
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
attn_vec, attn_prob = attn_vec attn_vec, attn_prob = attn_vec
# post processing # post processing
...@@ -286,7 +290,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -286,7 +290,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
output_g = None output_g = None
outputs = (output_h, output_g) outputs = (output_h, output_g)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (attn_prob,) outputs = outputs + (attn_prob,)
return outputs return outputs
...@@ -361,8 +365,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -361,8 +365,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.mem_len = config.mem_len self.mem_len = config.mem_len
self.reuse_len = config.reuse_len self.reuse_len = config.reuse_len
...@@ -508,6 +512,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -508,6 +512,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=True, use_cache=True,
output_attentions=None,
training=False, training=False,
): ):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
...@@ -521,7 +526,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -521,7 +526,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
head_mask = inputs[7] if len(inputs) > 7 else head_mask head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache use_cache = inputs[9] if len(inputs) > 9 else use_cache
assert len(inputs) <= 10, "Too many inputs." output_attentions = inputs[-9] if len(inputs) > 10 else output_attentions
assert len(inputs) <= 11, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
...@@ -533,10 +539,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -533,10 +539,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache) use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 10, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 11, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension # but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end # so we move here the first dimension (batch) to the end
...@@ -668,11 +677,22 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -668,11 +677,22 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
hidden_states.append((output_h, output_g) if output_g is not None else output_h) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
outputs = layer_module( outputs = layer_module(
[output_h, output_g, non_tgt_mask, attn_mask, pos_emb, seg_mat, mems[i], target_mapping, head_mask[i]], [
output_h,
output_g,
non_tgt_mask,
attn_mask,
pos_emb,
seg_mat,
mems[i],
target_mapping,
head_mask[i],
output_attentions,
],
training=training, training=training,
) )
output_h, output_g = outputs[:2] output_h, output_g = outputs[:2]
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
attentions.append(outputs[2]) attentions.append(outputs[2])
# Add last hidden state # Add last hidden state
...@@ -693,7 +713,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -693,7 +713,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
else: else:
hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states) hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
outputs = outputs + (hidden_states,) outputs = outputs + (hidden_states,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
outputs = outputs + (attentions,) outputs = outputs + (attentions,)
...@@ -817,7 +837,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel): ...@@ -817,7 +837,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -901,7 +921,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel): ...@@ -901,7 +921,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -969,6 +989,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif ...@@ -969,6 +989,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
inputs_embeds=None, inputs_embeds=None,
use_cache=True, use_cache=True,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -991,7 +1012,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif ...@@ -991,7 +1012,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -1022,6 +1043,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif ...@@ -1022,6 +1043,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
output = transformer_outputs[0] output = transformer_outputs[0]
...@@ -1077,6 +1099,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1077,6 +1099,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
inputs_embeds=None, inputs_embeds=None,
use_cache=True, use_cache=True,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -1096,7 +1119,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1096,7 +1119,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -1129,7 +1152,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1129,7 +1152,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
head_mask = inputs[7] if len(inputs) > 7 else head_mask head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache use_cache = inputs[9] if len(inputs) > 9 else use_cache
assert len(inputs) <= 10, "Too many inputs." output_attentions = inputs[-9] if len(inputs) > 10 else output_attentions
assert len(inputs) <= 11, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
...@@ -1141,7 +1165,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1141,7 +1165,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache) use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 10, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 11, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
...@@ -1168,6 +1193,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1168,6 +1193,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
head_mask, head_mask,
inputs_embeds, inputs_embeds,
use_cache, use_cache,
output_attentions,
] ]
transformer_outputs = self.transformer(flat_inputs, training=training) transformer_outputs = self.transformer(flat_inputs, training=training)
...@@ -1213,6 +1239,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio ...@@ -1213,6 +1239,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
inputs_embeds=None, inputs_embeds=None,
use_cache=True, use_cache=True,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -1233,7 +1260,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio ...@@ -1233,7 +1260,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -1264,6 +1291,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio ...@@ -1264,6 +1291,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
training=training, training=training,
) )
output = transformer_outputs[0] output = transformer_outputs[0]
...@@ -1310,6 +1338,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1310,6 +1338,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
cls_index=None, cls_index=None,
p_mask=None, p_mask=None,
is_impossible=None, is_impossible=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -1339,7 +1368,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1339,7 +1368,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -1372,6 +1401,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1372,6 +1401,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
training=training, training=training,
) )
...@@ -1425,7 +1455,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1425,7 +1455,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
# list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings) # list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
# of shape ``(batch_size, sequence_length, hidden_size)``: # of shape ``(batch_size, sequence_length, hidden_size)``:
# Hidden-states of the model at the output of each layer plus the initial embedding outputs. # Hidden-states of the model at the output of each layer plus the initial embedding outputs.
# **attentions**: (`optional`, returned when ``config.output_attentions=True``) # **attentions**: (`optional`, returned when ``output_attentions=True``)
# list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: # list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
# Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. # Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -288,7 +288,7 @@ class TFXxxModel(TFXxxPreTrainedModel): ...@@ -288,7 +288,7 @@ class TFXxxModel(TFXxxPreTrainedModel):
list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings) list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``) **attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
...@@ -329,7 +329,7 @@ class TFXxxForMaskedLM(TFXxxPreTrainedModel): ...@@ -329,7 +329,7 @@ class TFXxxForMaskedLM(TFXxxPreTrainedModel):
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings) list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``) **attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
...@@ -378,7 +378,7 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel): ...@@ -378,7 +378,7 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings) list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``) **attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
...@@ -433,7 +433,7 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel): ...@@ -433,7 +433,7 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel):
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings) list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``) **attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
...@@ -490,7 +490,7 @@ class TFXxxForQuestionAnswering(TFXxxPreTrainedModel): ...@@ -490,7 +490,7 @@ class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings) list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``) **attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......
This diff is collapsed.
...@@ -296,7 +296,7 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -296,7 +296,7 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
test_headmasking = False # head masking is not supported test_headmasking = False # head masking is not supported
test_torchscript = False test_torchscript = False
all_model_classes = (LongformerForMaskedLM, LongformerModel) if is_torch_available() else () all_model_classes = (LongformerModel, LongformerForMaskedLM,) if is_torch_available() else ()
def setUp(self): def setUp(self):
self.model_tester = LongformerModelTester(self) self.model_tester = LongformerModelTester(self)
......
This diff is collapsed.
...@@ -238,7 +238,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -238,7 +238,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
_, _, attentions = model(input_ids_1, target_mapping=target_mapping) _, _, attentions = model(input_ids_1, target_mapping=target_mapping, output_attentions=True)
self.parent.assertEqual(len(attentions), config.n_layer) self.parent.assertEqual(len(attentions), config.n_layer)
self.parent.assertIsInstance(attentions[0], tuple) self.parent.assertIsInstance(attentions[0], tuple)
...@@ -483,7 +483,6 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -483,7 +483,6 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
def test_xlnet_base_model_with_att_output(self): def test_xlnet_base_model_with_att_output(self):
self.model_tester.set_seed() self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs[0].output_attentions = True
self.model_tester.create_and_check_xlnet_base_model_with_att_output(*config_and_inputs) self.model_tester.create_and_check_xlnet_base_model_with_att_output(*config_and_inputs)
def test_xlnet_lm_head(self): def test_xlnet_lm_head(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment