"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "1555f298c13c392d8c666844183ae9bfd834c840"
Unverified Commit 6e603cb7 authored by Bharat Raghunathan's avatar Bharat Raghunathan Committed by GitHub
Browse files

[All models] Extend config.output_attentions with output_attentions function arguments (#4538)



* DOC: Replace instances of ``config.output_attentions`` with function argument ``output_attentions``

* DOC: Apply Black Formatting

* Fix errors where output_attentions was undefined

* Remove output_attentions in classes per review

* Fix regressions on tests having `output_attention`

* Fix further regressions in tests relating to `output_attentions`

Ensure proper propagation of `output_attentions` as a function parameter
to all model subclasses

* Fix more regressions in `test_output_attentions`

* Fix issues with BertEncoder

* Rename related variables to `output_attentions`

* fix pytorch tests

* fix bert and gpt2 tf

* Fix most TF tests for `test_output_attentions`

* Fix linter errors and more TF tests

* fix conflicts

* DOC: Apply Black Formatting

* Fix errors where output_attentions was undefined

* Remove output_attentions in classes per review

* Fix regressions on tests having `output_attention`

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* fix pytorch tests

* fix conflicts

* fix conflicts

* Fix linter errors and more TF tests

* fix tf tests

* make style

* fix isort

* improve output_attentions

* improve tensorflow
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent f90bc44d
......@@ -134,6 +134,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
head_mask=None,
inputs_embeds=None,
training=False,
output_attentions=False,
):
# removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)):
......@@ -255,7 +256,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
if not self.pre_norm:
attn_outputs = self.attentions[i]([tensor, attn_mask, None, cache, head_mask[i]], training=training)
attn = attn_outputs[0]
if self.output_attentions:
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
tensor = tensor + attn
......@@ -266,7 +267,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
[tensor_normalized, attn_mask, None, cache, head_mask[i]], training=training
)
attn = attn_outputs[0]
if self.output_attentions:
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
tensor = tensor + attn
......@@ -302,7 +303,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
outputs = (tensor,)
if self.output_hidden_states:
outputs = outputs + (hidden_states,)
if self.output_attentions:
if output_attentions:
outputs = outputs + (attentions,)
return outputs # outputs, (hidden_states), (attentions)
......
......@@ -28,6 +28,7 @@ from .modeling_tf_utils import (
TFPreTrainedModel,
TFSequenceSummary,
TFSharedEmbeddings,
cast_bool_to_primitive,
get_initializer,
keras_serializable,
shape_list,
......@@ -63,7 +64,6 @@ def gelu(x):
class TFAttention(tf.keras.layers.Layer):
def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
super().__init__(**kwargs)
self.output_attentions = config.output_attentions
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
......@@ -93,7 +93,7 @@ class TFAttention(tf.keras.layers.Layer):
return tf.cast(m, dtype)
def _attn(self, inputs, training=False):
q, k, v, attention_mask, head_mask = inputs
q, k, v, attention_mask, head_mask, output_attentions = inputs
# q, k, v have shape [batch, heads, sequence, features]
w = tf.matmul(q, k, transpose_b=True)
if self.scale:
......@@ -118,7 +118,7 @@ class TFAttention(tf.keras.layers.Layer):
w = w * head_mask
outputs = [tf.matmul(w, v)]
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
outputs.append(w)
return outputs
......@@ -135,7 +135,7 @@ class TFAttention(tf.keras.layers.Layer):
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
def call(self, inputs, training=False):
x, layer_past, attention_mask, head_mask, use_cache = inputs
x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
x = self.c_attn(x)
query, key, value = tf.split(x, 3, axis=2)
......@@ -148,20 +148,12 @@ class TFAttention(tf.keras.layers.Layer):
value = tf.concat([past_value, value], axis=-2)
# to cope with keras serialization
# we need to cast `use_cache` to correct bool
# if it is a tensor
if tf.is_tensor(use_cache):
if hasattr(use_cache, "numpy"):
use_cache = bool(use_cache.numpy())
else:
use_cache = True
if use_cache is True:
if cast_bool_to_primitive(use_cache, True) is True:
present = tf.stack([key, value], axis=0)
else:
present = (None,)
attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
attn_outputs = self._attn([query, key, value, attention_mask, head_mask, output_attentions], training=training)
a = attn_outputs[0]
a = self.merge_heads(a)
......@@ -198,10 +190,12 @@ class TFBlock(tf.keras.layers.Layer):
self.mlp = TFMLP(4 * nx, config, name="mlp")
def call(self, inputs, training=False):
x, layer_past, attention_mask, head_mask, use_cache = inputs
x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
a = self.ln_1(x)
output_attn = self.attn([a, layer_past, attention_mask, head_mask, use_cache], training=training)
output_attn = self.attn(
[a, layer_past, attention_mask, head_mask, use_cache, output_attentions], training=training
)
a = output_attn[0] # output_attn: a, present, (attentions)
x = x + a
......@@ -219,8 +213,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.num_hidden_layers = config.n_layer
self.vocab_size = config.vocab_size
self.n_embd = config.n_embd
......@@ -261,6 +255,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
inputs_embeds=None,
use_cache=True,
training=False,
output_attentions=None,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
......@@ -271,7 +266,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
use_cache = inputs[7] if len(inputs) > 7 else use_cache
assert len(inputs) <= 8, "Too many inputs."
output_attentions = inputs[8] if len(inputs) > 7 else output_attentions
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
......@@ -281,10 +277,13 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 8, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
......@@ -355,12 +354,15 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block([hidden_states, layer_past, attention_mask, head_mask[i], use_cache], training=training)
outputs = block(
[hidden_states, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
training=training,
)
hidden_states, present = outputs[:2]
presents = presents + (present,)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
all_attentions.append(outputs[2])
hidden_states = self.ln_f(hidden_states)
......@@ -376,7 +378,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
outputs = outputs + (presents,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
......@@ -493,7 +495,7 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -552,7 +554,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -614,6 +616,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
inputs_embeds=None,
mc_token_ids=None,
use_cache=True,
output_attentions=None,
training=False,
):
r"""
......@@ -636,7 +639,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -681,7 +684,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
use_cache = inputs[8] if len(inputs) > 8 else use_cache
assert len(inputs) <= 9, "Too many inputs."
output_attentions = inputs[9] if len(inputs) > 8 else output_attentions
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
......@@ -692,7 +696,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 9, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
......@@ -717,6 +722,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
head_mask,
inputs_embeds,
use_cache,
output_attentions,
]
transformer_outputs = self.transformer(flat_inputs, training=training)
......
......@@ -28,6 +28,7 @@ from .modeling_tf_utils import (
TFPreTrainedModel,
TFSequenceSummary,
TFSharedEmbeddings,
cast_bool_to_primitive,
get_initializer,
keras_serializable,
shape_list,
......@@ -70,7 +71,6 @@ ACT_FNS = {
class TFAttention(tf.keras.layers.Layer):
def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
super().__init__(**kwargs)
self.output_attentions = config.output_attentions
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
......@@ -100,7 +100,7 @@ class TFAttention(tf.keras.layers.Layer):
return tf.cast(m, dtype)
def _attn(self, inputs, training=False):
q, k, v, attention_mask, head_mask = inputs
q, k, v, attention_mask, head_mask, output_attentions = inputs
# q, k, v have shape [batch, heads, sequence, features]
w = tf.matmul(q, k, transpose_b=True)
if self.scale:
......@@ -125,7 +125,7 @@ class TFAttention(tf.keras.layers.Layer):
w = w * head_mask
outputs = [tf.matmul(w, v)]
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
outputs.append(w)
return outputs
......@@ -142,7 +142,7 @@ class TFAttention(tf.keras.layers.Layer):
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
def call(self, inputs, training=False):
x, attention_mask, head_mask = inputs
x, attention_mask, head_mask, output_attentions = inputs
x = self.c_attn(x)
query, key, value = tf.split(x, 3, axis=2)
......@@ -150,7 +150,7 @@ class TFAttention(tf.keras.layers.Layer):
key = self.split_heads(key)
value = self.split_heads(value)
attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
attn_outputs = self._attn([query, key, value, attention_mask, head_mask, output_attentions], training=training)
a = attn_outputs[0]
a = self.merge_heads(a)
......@@ -187,9 +187,9 @@ class TFBlock(tf.keras.layers.Layer):
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
def call(self, inputs, training=False):
x, attention_mask, head_mask = inputs
x, attention_mask, head_mask, output_attentions = inputs
output_attn = self.attn([x, attention_mask, head_mask], training=training)
output_attn = self.attn([x, attention_mask, head_mask, output_attentions], training=training)
a = output_attn[0] # output_attn: a, (attentions)
n = self.ln_1(x + a)
......@@ -244,6 +244,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
......@@ -253,7 +254,8 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs."
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
......@@ -261,10 +263,13 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 6, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
......@@ -329,9 +334,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block([hidden_states, attention_mask, head_mask[i]], training=training)
outputs = block([hidden_states, attention_mask, head_mask[i], output_attentions], training=training)
hidden_states = outputs[0]
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
all_attentions.append(outputs[1])
hidden_states = tf.reshape(hidden_states, output_shape)
......@@ -342,7 +347,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
......@@ -448,7 +453,7 @@ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -496,7 +501,7 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -555,6 +560,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
head_mask=None,
inputs_embeds=None,
mc_token_ids=None,
output_attentions=None,
training=False,
):
r"""
......@@ -577,7 +583,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -617,7 +623,8 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
assert len(inputs) <= 7, "Too many inputs."
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
......@@ -626,7 +633,8 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
assert len(inputs) <= 7, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
......@@ -649,6 +657,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
flat_position_ids,
head_mask,
inputs_embeds,
output_attentions,
]
transformer_outputs = self.transformer(flat_inputs, training=training)
......
......@@ -213,7 +213,7 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
......@@ -289,7 +289,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
......@@ -365,6 +365,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
training=False,
):
r"""
......@@ -377,7 +378,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
......@@ -403,6 +404,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training,
)
......@@ -452,6 +454,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
training=False,
):
r"""
......@@ -471,7 +474,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
......@@ -576,6 +579,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
training=False,
):
r"""
......@@ -592,7 +596,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
......@@ -618,6 +622,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training,
)
......@@ -663,6 +668,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
cls_index=None,
p_mask=None,
is_impossible=None,
output_attentions=None,
training=False,
):
r"""
......@@ -686,7 +692,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
......@@ -717,6 +723,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training,
)
......
......@@ -25,7 +25,13 @@ import tensorflow as tf
from .configuration_t5 import T5Config
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
from .modeling_tf_utils import (
TFPreTrainedModel,
TFSharedEmbeddings,
cast_bool_to_primitive,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
......@@ -105,7 +111,6 @@ class TFT5Attention(tf.keras.layers.Layer):
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias
self.output_attentions = config.output_attentions
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.d_model = config.d_model
self.d_kv = config.d_kv
......@@ -198,6 +203,7 @@ class TFT5Attention(tf.keras.layers.Layer):
query_length=None,
use_cache=False,
training=False,
output_attentions=False,
):
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
......@@ -250,13 +256,7 @@ class TFT5Attention(tf.keras.layers.Layer):
k, v = past_key_value_state
# to cope with keras serialization
# we need to cast `use_cache` to correct bool
# if it is a tensor
if tf.is_tensor(use_cache):
if hasattr(use_cache, "numpy"):
use_cache = bool(use_cache.numpy())
else:
use_cache = True
use_cache = cast_bool_to_primitive(use_cache)
if self.is_decoder and use_cache is True:
present_key_value_state = ((k, v),)
......@@ -293,7 +293,7 @@ class TFT5Attention(tf.keras.layers.Layer):
outputs = (context,) + present_key_value_state
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (weights,)
if self.has_relative_attention_bias:
outputs = outputs + (position_bias,)
......@@ -317,6 +317,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
head_mask=None,
past_key_value_state=None,
use_cache=False,
output_attentions=False,
training=False,
):
norm_x = self.layer_norm(hidden_states)
......@@ -327,6 +328,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
head_mask=head_mask,
past_key_value_state=past_key_value_state,
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
)
y = attention_output[0]
......@@ -354,6 +356,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
past_key_value_state=None,
query_length=None,
use_cache=False,
output_attentions=False,
training=False,
):
norm_x = self.layer_norm(hidden_states)
......@@ -366,6 +369,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
past_key_value_state=past_key_value_state,
query_length=query_length,
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
)
y = attention_output[0]
......@@ -402,6 +406,7 @@ class TFT5Block(tf.keras.layers.Layer):
head_mask=None,
past_key_value_state=None,
use_cache=False,
output_attentions=False,
training=False,
):
......@@ -428,6 +433,7 @@ class TFT5Block(tf.keras.layers.Layer):
head_mask=head_mask,
past_key_value_state=self_attn_past_key_value_state,
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
)
hidden_states, present_key_value_state = self_attention_outputs[:2]
......@@ -450,6 +456,7 @@ class TFT5Block(tf.keras.layers.Layer):
past_key_value_state=cross_attn_past_key_value_state,
query_length=query_length,
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
)
hidden_states = cross_attention_outputs[0]
......@@ -509,8 +516,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
def __init__(self, config, embed_tokens=None, **kwargs):
super().__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.embed_tokens = embed_tokens
self.is_decoder = config.is_decoder
......@@ -550,6 +557,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
head_mask=None,
past_key_value_states=None,
use_cache=False,
output_attentions=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
......@@ -560,7 +568,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
head_mask = inputs[5] if len(inputs) > 5 else head_mask
past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states
assert len(inputs) <= 7, "Too many inputs."
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("decoder_input_ids")
attention_mask = inputs.get("decoder_attention_mask", attention_mask)
......@@ -569,10 +578,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
head_mask = inputs.get("head_mask", head_mask)
past_key_value_states = inputs.get("past_key_value_states", past_key_value_states)
assert len(inputs) <= 7, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both inputs and inputs_embeds at the same time")
elif input_ids is not None:
......@@ -697,6 +709,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
head_mask=head_mask[i],
past_key_value_state=past_key_value_state,
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
)
# layer_outputs is a tuple with:
......@@ -705,13 +718,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
if i == 0:
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[3 if self.output_attentions else 2]
position_bias = layer_outputs[3 if output_attentions else 2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[5 if self.output_attentions else 3]
encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]
# append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
all_attentions = all_attentions + (layer_outputs[2],)
hidden_states = self.final_layer_norm(hidden_states)
......@@ -727,7 +740,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
outputs = outputs + (present_key_value_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
......@@ -896,7 +909,7 @@ class TFT5Model(TFT5PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -931,11 +944,16 @@ class TFT5Model(TFT5PreTrainedModel):
decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
use_cache = kwargs.get("use_cache", True)
head_mask = kwargs.get("head_mask", None)
output_attentions = kwargs.get("output_attentions", None)
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
encoder_outputs = self.encoder(
inputs, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask,
inputs,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
)
hidden_states = encoder_outputs[0]
......@@ -958,6 +976,7 @@ class TFT5Model(TFT5PreTrainedModel):
encoder_attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
if use_cache is True:
......@@ -1018,7 +1037,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -1057,12 +1076,17 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
inputs_embeds = kwargs.get("inputs_embeds", None)
decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
head_mask = kwargs.get("head_mask", None)
output_attentions = kwargs.get("output_attentions", None)
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
encoder_outputs = self.encoder(
inputs, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask,
inputs,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
)
hidden_states = encoder_outputs[0]
......@@ -1085,6 +1109,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
encoder_attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
# insert decoder past at right place
......
......@@ -24,7 +24,13 @@ import tensorflow as tf
from .configuration_transfo_xl import TransfoXLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
from .modeling_tf_utils import (
TFPreTrainedModel,
cast_bool_to_primitive,
get_initializer,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
......@@ -109,14 +115,12 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
pre_lnorm=False,
r_r_bias=None,
r_w_bias=None,
output_attentions=False,
layer_norm_epsilon=1e-5,
init_std=0.02,
**kwargs
):
super().__init__(**kwargs)
self.output_attentions = output_attentions
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
......@@ -170,7 +174,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
return x
def call(self, inputs, training=False):
w, r, attn_mask, mems, head_mask = inputs
w, r, attn_mask, mems, head_mask, output_attentions = inputs
qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
if mems is not None:
......@@ -243,7 +247,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
# residual connection + layer normalization
outputs = [self.layer_norm(w + attn_out)]
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
outputs.append(attn_prob)
return outputs
......@@ -264,7 +268,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
pre_lnorm=False,
r_w_bias=None,
r_r_bias=None,
output_attentions=False,
layer_norm_epsilon=1e-5,
init_std=0.02,
**kwargs
......@@ -284,7 +287,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
r_w_bias=r_w_bias,
r_r_bias=r_r_bias,
init_std=init_std,
output_attentions=output_attentions,
layer_norm_epsilon=layer_norm_epsilon,
name="dec_attn",
)
......@@ -299,8 +301,10 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
)
def call(self, inputs, training=False):
dec_inp, r, dec_attn_mask, mems, head_mask = inputs
attn_outputs = self.dec_attn([dec_inp, r, dec_attn_mask, mems, head_mask], training=training)
dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions = inputs
attn_outputs = self.dec_attn(
[dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions], training=training
)
ff_output = self.pos_ff(attn_outputs[0], training=training)
outputs = [ff_output] + attn_outputs[1:]
......@@ -386,8 +390,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.n_token = config.vocab_size
......@@ -435,7 +439,6 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
pre_lnorm=config.pre_lnorm,
r_w_bias=None if self.untie_r else self.r_w_bias,
r_r_bias=None if self.untie_r else self.r_r_bias,
output_attentions=self.output_attentions,
layer_norm_epsilon=config.layer_norm_epsilon,
init_std=config.init_std,
name="layers_._{}".format(i),
......@@ -514,22 +517,26 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
return new_mems
def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, training=False):
def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, output_attentions=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
mems = inputs[1] if len(inputs) > 1 else mems
head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
assert len(inputs) <= 4, "Too many inputs."
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
assert len(inputs) <= 5, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
mems = inputs.get("mems", mems)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 4, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 5, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
if input_ids is not None and inputs_embeds is not None:
......@@ -600,9 +607,11 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
for i, layer in enumerate(self.layers):
hids.append(core_out)
mems_i = None if mems is None else mems[i]
layer_outputs = layer([core_out, pos_emb, dec_attn_mask, mems_i, head_mask[i]], training=training)
layer_outputs = layer(
[core_out, pos_emb, dec_attn_mask, mems_i, head_mask[i], output_attentions], training=training,
)
core_out = layer_outputs[0]
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
attentions.append(layer_outputs[1])
else: # learnable embeddings and absolute embeddings
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
......@@ -618,7 +627,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
hids.append(core_out)
hids = list(tf.transpose(t, perm=(1, 0, 2)) for t in hids)
outputs.append(hids)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
attentions = list(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
outputs.append(attentions)
......@@ -711,7 +720,7 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -785,7 +794,16 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
return self.transformer.init_mems(bsz)
@add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, labels=None, training=False):
def call(
self,
inputs,
mems=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
training=False,
):
r"""
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.TransfoXLConfig`) and inputs:
......@@ -800,7 +818,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -825,14 +843,16 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
labels = inputs[4] if len(inputs) > 4 else labels
assert len(inputs) <= 5, "Too many inputs."
output_attentions = inputs[5] if len(inputs) > 5 else output_attentions
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get("input_ids")
mems = inputs.get("mems", mems)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
labels = inputs.get("labels", labels)
assert len(inputs) <= 5, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
......@@ -841,7 +861,9 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
else:
bsz, tgt_len = shape_list(inputs_embeds)[:2]
transformer_outputs = self.transformer([input_ids, mems, head_mask, inputs_embeds], training=training)
transformer_outputs = self.transformer(
[input_ids, mems, head_mask, inputs_embeds, output_attentions], training=training
)
last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:]
......
......@@ -1755,3 +1755,24 @@ def get_initializer(initializer_range=0.02):
TruncatedNormal initializer with stddev = `initializer_range`.
"""
return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
def cast_bool_to_primitive(bool_variable, default_tensor_to_true=False):
"""Function arguments can be inserted as boolean tensor
and bool variables to cope with keras serialization
we need to cast `output_attentions` to correct bool
if it is a tensor
Args:
default_tensor_to_true: bool, if tensor should default to True
in case tensor has no numpy attribute
"""
# if bool variable is tensor and has numpy value
if tf.is_tensor(bool_variable):
if hasattr(bool_variable, "numpy"):
return bool(bool_variable.numpy())
elif default_tensor_to_true:
return True
# else variable is bool
return bool_variable
......@@ -33,6 +33,7 @@ from .modeling_tf_utils import (
TFSequenceSummary,
TFSharedEmbeddings,
TFTokenClassificationLoss,
cast_bool_to_primitive,
get_initializer,
keras_serializable,
shape_list,
......@@ -112,7 +113,6 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, n_heads, dim, config, **kwargs):
super().__init__(**kwargs)
self.layer_id = next(TFMultiHeadAttention.NEW_ID)
self.output_attentions = config.output_attentions
self.dim = dim
self.n_heads = n_heads
assert self.dim % self.n_heads == 0
......@@ -131,7 +131,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
"""
input, mask, kv, cache, head_mask = inputs
input, mask, kv, cache, head_mask, output_attentions = inputs
# Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
bs, qlen, dim = shape_list(input)
......@@ -188,7 +188,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
context = unshape(context) # (bs, qlen, dim)
outputs = (self.out_lin(context),)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (weights,)
return outputs
......@@ -215,8 +215,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
# encoder / decoder, output layer
self.is_encoder = config.is_encoder
......@@ -327,6 +327,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
training=False,
): # removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)):
......@@ -339,7 +340,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs."
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
......@@ -350,10 +352,13 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 9, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
......@@ -440,9 +445,11 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
hidden_states = hidden_states + (tensor,)
# self attention
attn_outputs = self.attentions[i]([tensor, attn_mask, None, cache, head_mask[i]], training=training)
attn_outputs = self.attentions[i](
[tensor, attn_mask, None, cache, head_mask[i], output_attentions], training=training
)
attn = attn_outputs[0]
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
tensor = tensor + attn
......@@ -474,7 +481,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
outputs = (tensor,)
if self.output_hidden_states:
outputs = outputs + (hidden_states,)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (attentions,)
return outputs # outputs, (hidden_states), (attentions)
......@@ -602,7 +609,7 @@ class TFXLMModel(TFXLMPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -698,7 +705,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -752,6 +759,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
training=False,
):
r"""
......@@ -770,7 +778,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -800,6 +808,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training,
)
output = transformer_outputs[0]
......@@ -849,6 +858,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
training=False,
):
r"""
......@@ -868,7 +878,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
......@@ -900,7 +910,8 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs."
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
......@@ -911,7 +922,8 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 9, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
......@@ -937,6 +949,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
cache,
head_mask,
inputs_embeds,
output_attentions,
]
transformer_outputs = self.transformer(flat_inputs, training=training)
......@@ -982,6 +995,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
training=False,
):
r"""
......@@ -998,7 +1012,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
......@@ -1024,6 +1038,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training,
)
......@@ -1071,6 +1086,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
cls_index=None,
p_mask=None,
is_impossible=None,
output_attentions=None,
training=False,
):
r"""
......@@ -1094,7 +1110,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -1127,6 +1143,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training,
)
......
......@@ -32,6 +32,7 @@ from .modeling_tf_utils import (
TFSequenceSummary,
TFSharedEmbeddings,
TFTokenClassificationLoss,
cast_bool_to_primitive,
get_initializer,
keras_serializable,
shape_list,
......@@ -71,7 +72,6 @@ ACT2FN = {
class TFXLNetRelativeAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_attentions = config.output_attentions
if config.d_model % config.n_head != 0:
raise ValueError(
......@@ -137,7 +137,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
def rel_attn_core(self, inputs, training=False):
"""Core relative positional attention operations."""
q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask = inputs
q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask, output_attentions = inputs
# content based attention score
ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h)
......@@ -174,7 +174,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# attention output
attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
return attn_vec, attn_prob
return attn_vec
......@@ -195,7 +195,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
return output
def call(self, inputs, training=False):
(h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask) = inputs
(h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask, output_attentions) = inputs
if g is not None:
# Two-stream attention with relative positional encoding.
......@@ -220,10 +220,11 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# core attention ops
attn_vec_h = self.rel_attn_core(
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask], training=training
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions],
training=training,
)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
attn_vec_h, attn_prob_h = attn_vec_h
# post processing
......@@ -237,25 +238,27 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
if target_mapping is not None:
q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
attn_vec_g = self.rel_attn_core(
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask], training=training
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions],
training=training,
)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
attn_vec_g, attn_prob_g = attn_vec_g
attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
else:
attn_vec_g = self.rel_attn_core(
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask], training=training
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions],
training=training,
)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
attn_vec_g, attn_prob_g = attn_vec_g
# post processing
output_g = self.post_attention([g, attn_vec_g], training=training)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
attn_prob = attn_prob_h, attn_prob_g
else:
......@@ -275,10 +278,11 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# core attention ops
attn_vec = self.rel_attn_core(
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask], training=training
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions],
training=training,
)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
attn_vec, attn_prob = attn_vec
# post processing
......@@ -286,7 +290,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
output_g = None
outputs = (output_h, output_g)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (attn_prob,)
return outputs
......@@ -361,8 +365,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.mem_len = config.mem_len
self.reuse_len = config.reuse_len
......@@ -508,6 +512,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
head_mask=None,
inputs_embeds=None,
use_cache=True,
output_attentions=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
......@@ -521,7 +526,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache
assert len(inputs) <= 10, "Too many inputs."
output_attentions = inputs[-9] if len(inputs) > 10 else output_attentions
assert len(inputs) <= 11, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
......@@ -533,10 +539,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 10, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 11, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end
......@@ -668,11 +677,22 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
outputs = layer_module(
[output_h, output_g, non_tgt_mask, attn_mask, pos_emb, seg_mat, mems[i], target_mapping, head_mask[i]],
[
output_h,
output_g,
non_tgt_mask,
attn_mask,
pos_emb,
seg_mat,
mems[i],
target_mapping,
head_mask[i],
output_attentions,
],
training=training,
)
output_h, output_g = outputs[:2]
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
attentions.append(outputs[2])
# Add last hidden state
......@@ -693,7 +713,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
else:
hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
outputs = outputs + (hidden_states,)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
outputs = outputs + (attentions,)
......@@ -817,7 +837,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -901,7 +921,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -969,6 +989,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
inputs_embeds=None,
use_cache=True,
labels=None,
output_attentions=None,
training=False,
):
r"""
......@@ -991,7 +1012,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -1022,6 +1043,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
)
output = transformer_outputs[0]
......@@ -1077,6 +1099,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
inputs_embeds=None,
use_cache=True,
labels=None,
output_attentions=None,
training=False,
):
r"""
......@@ -1096,7 +1119,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
......@@ -1129,7 +1152,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache
assert len(inputs) <= 10, "Too many inputs."
output_attentions = inputs[-9] if len(inputs) > 10 else output_attentions
assert len(inputs) <= 11, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
......@@ -1141,7 +1165,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 10, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 11, "Too many inputs."
else:
input_ids = inputs
......@@ -1168,6 +1193,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
head_mask,
inputs_embeds,
use_cache,
output_attentions,
]
transformer_outputs = self.transformer(flat_inputs, training=training)
......@@ -1213,6 +1239,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
inputs_embeds=None,
use_cache=True,
labels=None,
output_attentions=None,
training=False,
):
r"""
......@@ -1233,7 +1260,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -1264,6 +1291,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
)
output = transformer_outputs[0]
......@@ -1310,6 +1338,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
cls_index=None,
p_mask=None,
is_impossible=None,
output_attentions=None,
training=False,
):
r"""
......@@ -1339,7 +1368,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -1372,6 +1401,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
)
......@@ -1425,7 +1455,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
# list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
# of shape ``(batch_size, sequence_length, hidden_size)``:
# Hidden-states of the model at the output of each layer plus the initial embedding outputs.
# **attentions**: (`optional`, returned when ``config.output_attentions=True``)
# **attentions**: (`optional`, returned when ``output_attentions=True``)
# list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
# Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......
......@@ -234,12 +234,10 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
pre_lnorm=False,
r_r_bias=None,
r_w_bias=None,
output_attentions=False,
layer_norm_epsilon=1e-5,
):
super().__init__()
self.output_attentions = output_attentions
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
......@@ -278,7 +276,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
return x
def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
def forward(self, w, r, attn_mask=None, mems=None, head_mask=None, output_attentions=False):
qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
if mems is not None:
......@@ -361,7 +359,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
# residual connection + layer normalization
outputs = [self.layer_norm(w + attn_out)]
if self.output_attentions:
if output_attentions:
outputs.append(attn_prob)
return outputs
......@@ -378,9 +376,11 @@ class RelPartialLearnableDecoderLayer(nn.Module):
d_model, d_inner, dropout, pre_lnorm=kwargs.get("pre_lnorm"), layer_norm_epsilon=layer_norm_epsilon
)
def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None, output_attentions=False):
attn_outputs = self.dec_attn(dec_inp, r, attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask)
attn_outputs = self.dec_attn(
dec_inp, r, attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask, output_attentions=output_attentions,
)
ff_output = self.pos_ff(attn_outputs[0])
outputs = [ff_output] + attn_outputs[1:]
......@@ -552,7 +552,6 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
class TransfoXLModel(TransfoXLPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.n_token = config.vocab_size
......@@ -598,7 +597,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions,
layer_norm_epsilon=config.layer_norm_epsilon,
)
)
......@@ -670,7 +668,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
return new_mems
@add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None):
def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None, output_attentions=None):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.TransfoXLConfig`) and inputs:
......@@ -685,7 +683,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -704,6 +702,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
last_hidden_states, mems = outputs[:2]
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
if input_ids is not None and inputs_embeds is not None:
......@@ -772,10 +772,15 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
hids.append(core_out)
mems_i = None if mems is None else mems[i]
layer_outputs = layer(
core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i, head_mask=head_mask[i]
core_out,
pos_emb,
dec_attn_mask=dec_attn_mask,
mems=mems_i,
head_mask=head_mask[i],
output_attentions=output_attentions,
)
core_out = layer_outputs[0]
if self.output_attentions:
if output_attentions:
attentions.append(layer_outputs[1])
else: # learnable embeddings and absolute embeddings
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
......@@ -791,7 +796,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
hids.append(core_out)
hids = list(t.transpose(0, 1).contiguous() for t in hids)
outputs.append(hids)
if self.output_attentions:
if output_attentions:
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs.append(attentions)
......@@ -848,7 +853,9 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
return self.transformer.init_mems(bsz)
@add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None, labels=None):
def forward(
self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None, labels=None, output_attentions=None
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for language modeling.
......@@ -872,7 +879,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -898,7 +905,9 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask, inputs_embeds=inputs_embeds)
transformer_outputs = self.transformer(
input_ids, mems=mems, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions
)
last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:]
......
......@@ -95,7 +95,6 @@ class MultiHeadAttention(nn.Module):
def __init__(self, n_heads, dim, config):
super().__init__()
self.layer_id = next(MultiHeadAttention.NEW_ID)
self.output_attentions = config.output_attentions
self.dim = dim
self.n_heads = n_heads
self.dropout = config.attention_dropout
......@@ -122,7 +121,7 @@ class MultiHeadAttention(nn.Module):
self.dim = attention_head_size * self.n_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, input, mask, kv=None, cache=None, head_mask=None):
def forward(self, input, mask, kv=None, cache=None, head_mask=None, output_attentions=False):
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
"""
......@@ -181,7 +180,7 @@ class MultiHeadAttention(nn.Module):
context = unshape(context) # (bs, qlen, dim)
outputs = (self.out_lin(context),)
if self.output_attentions:
if output_attentions:
outputs = outputs + (weights,)
return outputs
......@@ -313,7 +312,6 @@ XLM_INPUTS_DOCSTRING = r"""
class XLMModel(XLMPreTrainedModel):
def __init__(self, config): # , dico, is_encoder, with_output):
super().__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
# encoder / decoder, output layer
......@@ -407,6 +405,7 @@ class XLMModel(XLMPreTrainedModel):
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
):
r"""
Return:
......@@ -418,7 +417,7 @@ class XLMModel(XLMPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -437,6 +436,8 @@ class XLMModel(XLMPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if input_ids is not None:
bs, slen = input_ids.size()
else:
......@@ -512,9 +513,11 @@ class XLMModel(XLMPreTrainedModel):
hidden_states = hidden_states + (tensor,)
# self attention
attn_outputs = self.attentions[i](tensor, attn_mask, cache=cache, head_mask=head_mask[i])
attn_outputs = self.attentions[i](
tensor, attn_mask, cache=cache, head_mask=head_mask[i], output_attentions=output_attentions,
)
attn = attn_outputs[0]
if self.output_attentions:
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = F.dropout(attn, p=self.dropout, training=self.training)
tensor = tensor + attn
......@@ -546,7 +549,7 @@ class XLMModel(XLMPreTrainedModel):
outputs = (tensor,)
if self.output_hidden_states:
outputs = outputs + (hidden_states,)
if self.output_attentions:
if output_attentions:
outputs = outputs + (attentions,)
return outputs # outputs, (hidden_states), (attentions)
......@@ -636,6 +639,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
......@@ -656,7 +660,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -685,6 +689,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
)
output = transformer_outputs[0]
......@@ -722,6 +727,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
......@@ -741,7 +747,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -771,6 +777,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
)
output = transformer_outputs[0]
......@@ -819,6 +826,7 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
inputs_embeds=None,
start_positions=None,
end_positions=None,
output_attentions=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
......@@ -843,7 +851,7 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -874,6 +882,7 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
)
sequence_output = transformer_outputs[0]
......@@ -940,6 +949,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
is_impossible=None,
cls_index=None,
p_mask=None,
output_attentions=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
......@@ -977,7 +987,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -1008,6 +1018,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
)
output = transformer_outputs[0]
......@@ -1052,6 +1063,7 @@ class XLMForTokenClassification(XLMPreTrainedModel):
position_ids=None,
head_mask=None,
labels=None,
output_attentions=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
......@@ -1069,7 +1081,7 @@ class XLMForTokenClassification(XLMPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -1096,6 +1108,7 @@ class XLMForTokenClassification(XLMPreTrainedModel):
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
output_attentions=output_attentions,
)
sequence_output = outputs[0]
......
......@@ -193,7 +193,6 @@ XLNetLayerNorm = nn.LayerNorm
class XLNetRelativeAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.output_attentions = config.output_attentions
if config.d_model % config.n_head != 0:
raise ValueError(
......@@ -251,7 +250,17 @@ class XLNetRelativeAttention(nn.Module):
return x
def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None, head_mask=None):
def rel_attn_core(
self,
q_head,
k_head_h,
v_head_h,
k_head_r,
seg_mat=None,
attn_mask=None,
head_mask=None,
output_attentions=False,
):
"""Core relative positional attention operations."""
# content based attention score
......@@ -288,7 +297,7 @@ class XLNetRelativeAttention(nn.Module):
# attention output
attn_vec = torch.einsum("bnij,jbnd->ibnd", attn_prob, v_head_h)
if self.output_attentions:
if output_attentions:
return attn_vec, torch.einsum("bnij->ijbn", attn_prob)
return attn_vec
......@@ -305,7 +314,19 @@ class XLNetRelativeAttention(nn.Module):
return output
def forward(self, h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems=None, target_mapping=None, head_mask=None):
def forward(
self,
h,
g,
attn_mask_h,
attn_mask_g,
r,
seg_mat,
mems=None,
target_mapping=None,
head_mask=None,
output_attentions=False,
):
if g is not None:
# Two-stream attention with relative positional encoding.
# content based attention score
......@@ -329,10 +350,17 @@ class XLNetRelativeAttention(nn.Module):
# core attention ops
attn_vec_h = self.rel_attn_core(
q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask
q_head_h,
k_head_h,
v_head_h,
k_head_r,
seg_mat=seg_mat,
attn_mask=attn_mask_h,
head_mask=head_mask,
output_attentions=output_attentions,
)
if self.output_attentions:
if output_attentions:
attn_vec_h, attn_prob_h = attn_vec_h
# post processing
......@@ -346,25 +374,39 @@ class XLNetRelativeAttention(nn.Module):
if target_mapping is not None:
q_head_g = torch.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
attn_vec_g = self.rel_attn_core(
q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask
q_head_g,
k_head_h,
v_head_h,
k_head_r,
seg_mat=seg_mat,
attn_mask=attn_mask_g,
head_mask=head_mask,
output_attentions=output_attentions,
)
if self.output_attentions:
if output_attentions:
attn_vec_g, attn_prob_g = attn_vec_g
attn_vec_g = torch.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
else:
attn_vec_g = self.rel_attn_core(
q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask
q_head_g,
k_head_h,
v_head_h,
k_head_r,
seg_mat=seg_mat,
attn_mask=attn_mask_g,
head_mask=head_mask,
output_attentions=output_attentions,
)
if self.output_attentions:
if output_attentions:
attn_vec_g, attn_prob_g = attn_vec_g
# post processing
output_g = self.post_attention(g, attn_vec_g)
if self.output_attentions:
if output_attentions:
attn_prob = attn_prob_h, attn_prob_g
else:
......@@ -384,10 +426,17 @@ class XLNetRelativeAttention(nn.Module):
# core attention ops
attn_vec = self.rel_attn_core(
q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask
q_head_h,
k_head_h,
v_head_h,
k_head_r,
seg_mat=seg_mat,
attn_mask=attn_mask_h,
head_mask=head_mask,
output_attentions=output_attentions,
)
if self.output_attentions:
if output_attentions:
attn_vec, attn_prob = attn_vec
# post processing
......@@ -395,7 +444,7 @@ class XLNetRelativeAttention(nn.Module):
output_g = None
outputs = (output_h, output_g)
if self.output_attentions:
if output_attentions:
outputs = outputs + (attn_prob,)
return outputs
......@@ -431,7 +480,17 @@ class XLNetLayer(nn.Module):
self.dropout = nn.Dropout(config.dropout)
def forward(
self, output_h, output_g, attn_mask_h, attn_mask_g, r, seg_mat, mems=None, target_mapping=None, head_mask=None
self,
output_h,
output_g,
attn_mask_h,
attn_mask_g,
r,
seg_mat,
mems=None,
target_mapping=None,
head_mask=None,
output_attentions=False,
):
outputs = self.rel_attn(
output_h,
......@@ -443,6 +502,7 @@ class XLNetLayer(nn.Module):
mems=mems,
target_mapping=target_mapping,
head_mask=head_mask,
output_attentions=output_attentions,
)
output_h, output_g = outputs[:2]
......@@ -568,7 +628,6 @@ XLNET_INPUTS_DOCSTRING = r"""
class XLNetModel(XLNetPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.mem_len = config.mem_len
......@@ -701,6 +760,7 @@ class XLNetModel(XLNetPreTrainedModel):
head_mask=None,
inputs_embeds=None,
use_cache=True,
output_attentions=None,
):
r"""
Return:
......@@ -717,7 +777,7 @@ class XLNetModel(XLNetPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -738,6 +798,8 @@ class XLNetModel(XLNetPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end
......@@ -883,9 +945,10 @@ class XLNetModel(XLNetPreTrainedModel):
mems=mems[i],
target_mapping=target_mapping,
head_mask=head_mask[i],
output_attentions=output_attentions,
)
output_h, output_g = outputs[:2]
if self.output_attentions:
if output_attentions:
attentions.append(outputs[2])
# Add last hidden state
......@@ -906,7 +969,7 @@ class XLNetModel(XLNetPreTrainedModel):
else:
hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
outputs = outputs + (hidden_states,)
if self.output_attentions:
if output_attentions:
if target_mapping is not None:
# when target_mapping is provided, there are 2-tuple of attentions
attentions = tuple(
......@@ -985,6 +1048,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
inputs_embeds=None,
use_cache=True,
labels=None,
output_attentions=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_predict)`, `optional`, defaults to :obj:`None`):
......@@ -1011,7 +1075,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -1060,6 +1124,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
)
logits = self.lm_loss(transformer_outputs[0])
......@@ -1105,6 +1170,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
inputs_embeds=None,
use_cache=True,
labels=None,
output_attentions=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`)
......@@ -1128,7 +1194,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -1160,6 +1226,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
)
output = transformer_outputs[0]
......@@ -1210,6 +1277,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
inputs_embeds=None,
use_cache=True,
labels=None,
output_attentions=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
......@@ -1232,7 +1300,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -1266,6 +1334,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
)
sequence_output = outputs[0]
......@@ -1319,6 +1388,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
inputs_embeds=None,
use_cache=True,
labels=None,
output_attentions=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
......@@ -1343,7 +1413,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -1384,6 +1454,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
)
output = transformer_outputs[0]
......@@ -1433,6 +1504,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
use_cache=True,
start_positions=None,
end_positions=None,
output_attentions=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
......@@ -1461,7 +1533,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -1496,6 +1568,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
)
sequence_output = outputs[0]
......@@ -1562,6 +1635,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
is_impossible=None,
cls_index=None,
p_mask=None,
output_attentions=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
......@@ -1603,7 +1677,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......@@ -1636,6 +1710,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = transformer_outputs[0]
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
......
......@@ -288,7 +288,7 @@ class TFXxxModel(TFXxxPreTrainedModel):
list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
**attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......@@ -329,7 +329,7 @@ class TFXxxForMaskedLM(TFXxxPreTrainedModel):
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
**attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......@@ -378,7 +378,7 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
**attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......@@ -433,7 +433,7 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel):
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
**attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......@@ -490,7 +490,7 @@ class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
**attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......
......@@ -285,7 +285,7 @@ class XxxModel(XxxPreTrainedModel):
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
**attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......@@ -403,7 +403,7 @@ class XxxForMaskedLM(XxxPreTrainedModel):
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
**attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......@@ -483,7 +483,7 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
**attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......@@ -569,7 +569,7 @@ class XxxForTokenClassification(XxxPreTrainedModel):
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
**attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......@@ -663,7 +663,7 @@ class XxxForQuestionAnswering(XxxPreTrainedModel):
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
**attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......
......@@ -130,7 +130,7 @@ class ModelTesterMixin:
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes:
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = False
model = model_class(config)
model.to(torch_device)
......@@ -138,7 +138,18 @@ class ModelTesterMixin:
with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1]
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1]
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
......@@ -172,7 +183,7 @@ class ModelTesterMixin:
)
# Check attention is always last and order is fine
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = True
model = model_class(config)
model.to(torch_device)
......@@ -180,7 +191,6 @@ class ModelTesterMixin:
with torch.no_grad():
outputs = model(**inputs_dict)
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True)
self_attentions = outputs[-1]
......@@ -203,7 +213,6 @@ class ModelTesterMixin:
def test_torchscript_output_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_attentions = True
self._create_and_check_torchscript(config, inputs_dict)
......@@ -270,7 +279,7 @@ class ModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
global_rng.seed()
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = True
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
for model_class in self.all_model_classes:
......@@ -326,7 +335,7 @@ class ModelTesterMixin:
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = False
model = model_class(config=config)
model.to(torch_device)
......@@ -355,7 +364,7 @@ class ModelTesterMixin:
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = False
model = model_class(config=config)
model.to(torch_device)
......@@ -388,7 +397,7 @@ class ModelTesterMixin:
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = False
heads_to_prune = {
......@@ -419,7 +428,7 @@ class ModelTesterMixin:
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = False
heads_to_prune = {0: [0], 1: [1, 2]}
......@@ -471,14 +480,12 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
config.output_hidden_states = True
config.output_attentions = False
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**inputs_dict)
hidden_states = outputs[-1]
self.assertEqual(model.config.output_attentions, False)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
......@@ -838,7 +845,6 @@ class ModelUtilsTest(unittest.TestCase):
config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config)
......
......@@ -296,7 +296,7 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
test_headmasking = False # head masking is not supported
test_torchscript = False
all_model_classes = (LongformerForMaskedLM, LongformerModel) if is_torch_available() else ()
all_model_classes = (LongformerModel, LongformerForMaskedLM,) if is_torch_available() else ()
def setUp(self):
self.model_tester = LongformerModelTester(self)
......
......@@ -314,12 +314,11 @@ class TFModelTesterMixin:
)
for model_class in self.all_model_classes:
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = False
model = model_class(config)
outputs = model(inputs_dict)
attentions = [t.numpy() for t in outputs[-1]]
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
......@@ -331,7 +330,6 @@ class TFModelTesterMixin:
if self.is_encoder_decoder:
self.assertEqual(out_len % 2, 0)
decoder_attentions = outputs[(out_len // 2) - 1]
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
......@@ -339,13 +337,25 @@ class TFModelTesterMixin:
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# Check attention is always last and order is fine
# Check that output attentions can also be changed via the config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(inputs_dict)
attentions = [t.numpy() for t in outputs[-1]]
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
config.output_hidden_states = True
model = model_class(config)
outputs = model(inputs_dict)
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True)
attentions = [t.numpy() for t in outputs[-1]]
......@@ -360,11 +370,9 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes:
config.output_hidden_states = True
config.output_attentions = False
model = model_class(config)
outputs = model(inputs_dict)
hidden_states = [t.numpy() for t in outputs[-1]]
self.assertEqual(model.config.output_attentions, False)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
self.assertListEqual(
......
......@@ -238,7 +238,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
model.to(torch_device)
model.eval()
_, _, attentions = model(input_ids_1, target_mapping=target_mapping)
_, _, attentions = model(input_ids_1, target_mapping=target_mapping, output_attentions=True)
self.parent.assertEqual(len(attentions), config.n_layer)
self.parent.assertIsInstance(attentions[0], tuple)
......@@ -483,7 +483,6 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
def test_xlnet_base_model_with_att_output(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs[0].output_attentions = True
self.model_tester.create_and_check_xlnet_base_model_with_att_output(*config_and_inputs)
def test_xlnet_lm_head(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment