Unverified Commit 6e603cb7 authored by Bharat Raghunathan, committed by GitHub

[All models] Extend config.output_attentions with output_attentions function arguments (#4538)



* DOC: Replace instances of ``config.output_attentions`` with function argument ``output_attentions``

* DOC: Apply Black Formatting

* Fix errors where output_attentions was undefined

* Remove output_attentions in classes per review

* Fix regressions on tests having `output_attention`

* Fix further regressions in tests relating to `output_attentions`

Ensure proper propagation of `output_attentions` as a function parameter
to all model subclasses

* Fix more regressions in `test_output_attentions`

* Fix issues with BertEncoder

* Rename related variables to `output_attentions`

* fix pytorch tests

* fix bert and gpt2 tf

* Fix most TF tests for `test_output_attentions`

* Fix linter errors and more TF tests

* fix conflicts

* DOC: Apply Black Formatting

* Fix errors where output_attentions was undefined

* Remove output_attentions in classes per review

* Fix regressions on tests having `output_attention`

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* fix pytorch tests

* fix conflicts

* fix conflicts

* Fix linter errors and more TF tests

* fix tf tests

* make style

* fix isort

* improve output_attentions

* improve tensorflow
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent f90bc44d
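The effect of this commit is easiest to see from the caller's side: instead of baking `output_attentions` into the config when the model is created, it can now be requested per forward call, and the attention tuple is appended to the returned tuple when asked for. A minimal sketch of the new calling pattern, assuming a stock `bert-base-uncased` checkpoint and the tuple-style return used at this point in the library (names of the checkpoint and variables are illustrative):

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

input_ids = tokenizer.encode("Hello, world!", return_tensors="pt")

with torch.no_grad():
    # Before this change: attentions required config.output_attentions = True at load time.
    # After this change: request them per call; the output tuple gains a trailing attentions element.
    outputs = model(input_ids, output_attentions=True)

attentions = outputs[-1]  # one (batch, num_heads, seq_len, seq_len) tensor per layer
print(len(attentions), attentions[0].shape)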
@@ -44,8 +44,6 @@ class PretrainedConfig(object):
             Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
         num_labels (:obj:`int`, `optional`, defaults to `2`):
             Number of classes to use when the model is a classification model (sequences/tokens)
-        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Should the model returns attentions weights.
         output_hidden_states (:obj:`string`, `optional`, defaults to :obj:`False`):
             Should the model returns all hidden-states.
         torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -55,8 +53,8 @@ class PretrainedConfig(object):
     def __init__(self, **kwargs):
         # Attributes with defaults
-        self.output_attentions = kwargs.pop("output_attentions", False)
         self.output_hidden_states = kwargs.pop("output_hidden_states", False)
+        self.output_attentions = kwargs.pop("output_attentions", False)
         self.use_cache = kwargs.pop("use_cache", True)  # Not used by all models
         self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
         self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
......
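The model hunks that follow all resolve the per-call argument against the config default with the same one-liner: an explicit True/False from the caller wins, and None falls back to `config.output_attentions`. A standalone sketch of that precedence rule (the helper name is illustrative, not part of the library, which inlines this check in each forward):

def resolve_output_attentions(output_attentions, config):
    # None  -> fall back to the config-level default
    # bool  -> honour the caller's explicit per-call choice
    return output_attentions if output_attentions is not None else config.output_attentions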
@@ -187,7 +187,6 @@ class AlbertAttention(BertSelfAttention):
     def __init__(self, config):
         super().__init__(config)
-        self.output_attentions = config.output_attentions
         self.num_attention_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.attention_head_size = config.hidden_size // config.num_attention_heads
@@ -214,7 +213,7 @@ class AlbertAttention(BertSelfAttention):
         self.all_head_size = self.attention_head_size * self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
-    def forward(self, input_ids, attention_mask=None, head_mask=None):
+    def forward(self, input_ids, attention_mask=None, head_mask=None, output_attentions=False):
         mixed_query_layer = self.query(input_ids)
         mixed_key_layer = self.key(input_ids)
         mixed_value_layer = self.value(input_ids)
@@ -256,7 +255,7 @@ class AlbertAttention(BertSelfAttention):
         projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
         projected_context_layer_dropout = self.dropout(projected_context_layer)
         layernormed_context_layer = self.LayerNorm(input_ids + projected_context_layer_dropout)
-        return (layernormed_context_layer, attention_probs) if self.output_attentions else (layernormed_context_layer,)
+        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
 class AlbertLayer(nn.Module):
@@ -270,8 +269,8 @@ class AlbertLayer(nn.Module):
         self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
         self.activation = ACT2FN[config.hidden_act]
-    def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        attention_output = self.attention(hidden_states, attention_mask, head_mask)
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
+        attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
         ffn_output = self.ffn(attention_output[0])
         ffn_output = self.activation(ffn_output)
         ffn_output = self.ffn_output(ffn_output)
@@ -284,19 +283,18 @@ class AlbertLayerGroup(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
-    def forward(self, hidden_states, attention_mask=None, head_mask=None):
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
         layer_hidden_states = ()
         layer_attentions = ()
         for layer_index, albert_layer in enumerate(self.albert_layers):
-            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index])
+            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
             hidden_states = layer_output[0]
-            if self.output_attentions:
+            if output_attentions:
                 layer_attentions = layer_attentions + (layer_output[1],)
             if self.output_hidden_states:
@@ -305,7 +303,7 @@ class AlbertLayerGroup(nn.Module):
         outputs = (hidden_states,)
         if self.output_hidden_states:
             outputs = outputs + (layer_hidden_states,)
-        if self.output_attentions:
+        if output_attentions:
             outputs = outputs + (layer_attentions,)
         return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)
@@ -315,12 +313,11 @@ class AlbertTransformer(nn.Module):
         super().__init__()
         self.config = config
-        self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
         self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
-    def forward(self, hidden_states, attention_mask=None, head_mask=None):
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
         hidden_states = self.embedding_hidden_mapping_in(hidden_states)
         all_attentions = ()
@@ -339,10 +336,11 @@ class AlbertTransformer(nn.Module):
                 hidden_states,
                 attention_mask,
                 head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
+                output_attentions,
             )
             hidden_states = layer_group_output[0]
-            if self.output_attentions:
+            if output_attentions:
                 all_attentions = all_attentions + layer_group_output[-1]
             if self.output_hidden_states:
@@ -351,7 +349,7 @@ class AlbertTransformer(nn.Module):
         outputs = (hidden_states,)
         if self.output_hidden_states:
             outputs = outputs + (all_hidden_states,)
-        if self.output_attentions:
+        if output_attentions:
             outputs = outputs + (all_attentions,)
         return outputs  # last-layer hidden state, (all hidden states), (all attentions)
@@ -488,6 +486,7 @@ class AlbertModel(AlbertPreTrainedModel):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        output_attentions=None,
     ):
         r"""
     Return:
@@ -508,7 +507,7 @@ class AlbertModel(AlbertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -528,6 +527,8 @@ class AlbertModel(AlbertPreTrainedModel):
        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
@@ -552,7 +553,9 @@ class AlbertModel(AlbertPreTrainedModel):
         embedding_output = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
         )
-        encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)
+        encoder_outputs = self.encoder(
+            embedding_output, extended_attention_mask, head_mask=head_mask, output_attentions=output_attentions,
+        )
         sequence_output = encoder_outputs[0]
@@ -597,7 +600,8 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
         inputs_embeds=None,
         labels=None,
         sentence_order_label=None,
-        **kwargs
+        output_attentions=None,
+        **kwargs,
     ):
         r"""
         labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
@@ -627,7 +631,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -665,6 +669,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         sequence_output, pooled_output = outputs[:2]
@@ -750,6 +755,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         labels=None,
+        output_attentions=None,
         **kwargs
     ):
         r"""
@@ -772,7 +778,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -806,6 +812,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         sequence_outputs = outputs[0]
@@ -846,6 +853,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         labels=None,
+        output_attentions=None,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -865,7 +873,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -893,6 +901,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         pooled_output = outputs[1]
@@ -941,6 +950,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         labels=None,
+        output_attentions=None,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -958,7 +968,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -988,6 +998,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         sequence_output = outputs[0]
@@ -1038,6 +1049,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
         inputs_embeds=None,
         start_positions=None,
         end_positions=None,
+        output_attentions=None,
     ):
         r"""
         start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1062,7 +1074,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -1092,6 +1104,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         sequence_output = outputs[0]
......
@@ -183,7 +183,6 @@ class EncoderLayer(nn.Module):
     def __init__(self, config: BartConfig):
         super().__init__()
         self.embed_dim = config.d_model
-        self.output_attentions = config.output_attentions
         self.self_attn = SelfAttention(
             self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
         )
@@ -196,7 +195,7 @@ class EncoderLayer(nn.Module):
         self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
         self.final_layer_norm = LayerNorm(self.embed_dim)
-    def forward(self, x, encoder_padding_mask):
+    def forward(self, x, encoder_padding_mask, output_attentions=False):
         """
         Args:
             x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -212,7 +211,7 @@ class EncoderLayer(nn.Module):
         if self.normalize_before:
             x = self.self_attn_layer_norm(x)
         x, attn_weights = self.self_attn(
-            query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
+            query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -246,7 +245,6 @@ class BartEncoder(nn.Module):
         self.dropout = config.dropout
         self.layerdrop = config.encoder_layerdrop
-        self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         embed_dim = embed_tokens.embedding_dim
@@ -268,9 +266,7 @@ class BartEncoder(nn.Module):
         # mbart has one extra layer_norm
         self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
-    def forward(
-        self, input_ids, attention_mask=None,
-    ):
+    def forward(self, input_ids, attention_mask=None, output_attentions=False):
         """
         Args:
             input_ids (LongTensor): tokens in the source language of shape
@@ -308,9 +304,9 @@ class BartEncoder(nn.Module):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 attn = None
             else:
-                x, attn = encoder_layer(x, attention_mask)
+                x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)
-            if self.output_attentions:
+            if output_attentions:
                 all_attentions.append(attn)
         if self.layer_norm:
@@ -329,7 +325,6 @@ class DecoderLayer(nn.Module):
     def __init__(self, config: BartConfig):
         super().__init__()
         self.embed_dim = config.d_model
-        self.output_attentions = config.output_attentions
         self.self_attn = SelfAttention(
             embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
         )
@@ -358,6 +353,7 @@ class DecoderLayer(nn.Module):
         layer_state=None,
         causal_mask=None,
         decoder_padding_mask=None,
+        output_attentions=False,
     ):
         residual = x
@@ -373,7 +369,7 @@ class DecoderLayer(nn.Module):
             layer_state=layer_state,  # adds keys to layer state
             key_padding_mask=decoder_padding_mask,
             attn_mask=causal_mask,
-            need_weights=self.output_attentions,
+            output_attentions=output_attentions,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -425,7 +421,6 @@ class BartDecoder(nn.Module):
     def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
         super().__init__()
-        self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
@@ -456,7 +451,8 @@ class BartDecoder(nn.Module):
         decoder_causal_mask,
         decoder_cached_states=None,
         use_cache=False,
-        **unused
+        output_attentions=False,
+        **unused,
     ):
         """
         Includes several features from "Jointly Learning to Align and
@@ -518,6 +514,7 @@ class BartDecoder(nn.Module):
                 decoder_padding_mask=decoder_padding_mask,
                 layer_state=layer_state,
                 causal_mask=decoder_causal_mask,
+                output_attentions=output_attentions,
             )
             if use_cache:
@@ -525,7 +522,7 @@ class BartDecoder(nn.Module):
             if self.layer_norm and (idx == len(self.layers) - 1):  # last layer of mbart
                 x = self.layer_norm(x)
-            if self.output_attentions:
+            if output_attentions:
                 all_self_attns += (layer_self_attn,)
         # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
@@ -583,7 +580,7 @@ class SelfAttention(nn.Module):
         key_padding_mask: Optional[Tensor] = None,
         layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
         attn_mask: Optional[Tensor] = None,
-        need_weights=False,
+        output_attentions=False,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """Input shape: Time(SeqLen) x Batch x Channel"""
         static_kv: bool = self.encoder_decoder_attention
@@ -655,7 +652,7 @@ class SelfAttention(nn.Module):
         assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         attn_output = self.out_proj(attn_output)
-        if need_weights:
+        if output_attentions:
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
         else:
             attn_weights = None
@@ -797,7 +794,6 @@ def _get_shape(t):
 class BartModel(PretrainedBartModel):
     def __init__(self, config: BartConfig):
         super().__init__(config)
-        self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
@@ -818,7 +814,9 @@ class BartModel(PretrainedBartModel):
         decoder_attention_mask=None,
         decoder_cached_states=None,
         use_cache=False,
+        output_attentions=None,
     ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         # make masks if user doesn't supply
         if not use_cache:
@@ -833,8 +831,11 @@ class BartModel(PretrainedBartModel):
             decoder_padding_mask, causal_mask = None, None
         assert decoder_input_ids is not None
         if encoder_outputs is None:
-            encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+            encoder_outputs = self.encoder(
+                input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions,
+            )
         assert isinstance(encoder_outputs, tuple)
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         decoder_outputs = self.decoder(
@@ -844,8 +845,10 @@ class BartModel(PretrainedBartModel):
             decoder_padding_mask,
             decoder_causal_mask=causal_mask,
             decoder_cached_states=decoder_cached_states,
+            output_attentions=output_attentions,
             use_cache=use_cache,
         )
         # Attention and hidden_states will be [] or None if they aren't needed
         decoder_outputs: Tuple = _filter_out_falsey_values(decoder_outputs)
         assert isinstance(decoder_outputs[0], torch.Tensor)
@@ -903,7 +906,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
         decoder_cached_states=None,
         labels=None,
         use_cache=False,
-        **unused
+        output_attentions=None,
+        **unused,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -924,7 +928,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -961,6 +965,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
             decoder_attention_mask=decoder_attention_mask,
             decoder_cached_states=decoder_cached_states,
             use_cache=use_cache,
+            output_attentions=output_attentions,
         )
         lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
         outputs = (lm_logits,) + outputs[1:]  # Add cache, hidden states and attention if they are here
@@ -1055,6 +1060,7 @@ class BartForSequenceClassification(PretrainedBartModel):
         decoder_input_ids=None,
         decoder_attention_mask=None,
         labels=None,
+        output_attentions=None,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1072,7 +1078,7 @@ class BartForSequenceClassification(PretrainedBartModel):
             Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
             Attentions weights after the attention softmax, used to compute the weighted average in the
             self-attention
@@ -1098,6 +1104,7 @@ class BartForSequenceClassification(PretrainedBartModel):
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             encoder_outputs=encoder_outputs,
+            output_attentions=output_attentions,
         )
         x = outputs[0]  # last hidden state
         eos_mask = input_ids.eq(self.config.eos_token_id)
......
@@ -190,7 +190,6 @@ class BertSelfAttention(nn.Module):
                 "The hidden size (%d) is not a multiple of the number of attention "
                 "heads (%d)" % (config.hidden_size, config.num_attention_heads)
             )
-        self.output_attentions = config.output_attentions
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -214,6 +213,7 @@ class BertSelfAttention(nn.Module):
         head_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        output_attentions=False,
     ):
         mixed_query_layer = self.query(hidden_states)
@@ -256,7 +256,7 @@ class BertSelfAttention(nn.Module):
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
         return outputs
@@ -306,9 +306,10 @@ class BertAttention(nn.Module):
         head_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        output_attentions=False,
     ):
         self_outputs = self.self(
-            hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
+            hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
@@ -361,14 +362,22 @@ class BertLayer(nn.Module):
         head_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        output_attentions=False,
     ):
-        self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
+        self_attention_outputs = self.attention(
+            hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
+        )
         attention_output = self_attention_outputs[0]
         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
         if self.is_decoder and encoder_hidden_states is not None:
             cross_attention_outputs = self.crossattention(
-                attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
+                attention_output,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                output_attentions,
             )
             attention_output = cross_attention_outputs[0]
             outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights
@@ -382,7 +391,6 @@ class BertLayer(nn.Module):
 class BertEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
@@ -393,6 +401,7 @@ class BertEncoder(nn.Module):
         head_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        output_attentions=False,
    ):
         all_hidden_states = ()
         all_attentions = ()
@@ -401,11 +410,16 @@ class BertEncoder(nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)
             layer_outputs = layer_module(
-                hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
+                hidden_states,
+                attention_mask,
+                head_mask[i],
+                encoder_hidden_states,
+                encoder_attention_mask,
+                output_attentions,
             )
             hidden_states = layer_outputs[0]
-            if self.output_attentions:
+            if output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)
         # Add last layer
@@ -415,7 +429,7 @@ class BertEncoder(nn.Module):
         outputs = (hidden_states,)
         if self.output_hidden_states:
             outputs = outputs + (all_hidden_states,)
-        if self.output_attentions:
+        if output_attentions:
             outputs = outputs + (all_attentions,)
         return outputs  # last-layer hidden state, (all hidden states), (all attentions)
@@ -639,6 +653,7 @@ class BertModel(BertPreTrainedModel):
         inputs_embeds=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        output_attentions=None,
     ):
         r"""
     Return:
@@ -659,7 +674,7 @@ class BertModel(BertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -680,6 +695,7 @@ class BertModel(BertPreTrainedModel):
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -728,6 +744,7 @@ class BertModel(BertPreTrainedModel):
             head_mask=head_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_extended_attention_mask,
+            output_attentions=output_attentions,
         )
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output)
@@ -766,6 +783,7 @@ class BertForPreTraining(BertPreTrainedModel):
         inputs_embeds=None,
         labels=None,
         next_sentence_label=None,
+        output_attentions=None,
         **kwargs
     ):
         r"""
@@ -796,7 +814,7 @@ class BertForPreTraining(BertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -833,6 +851,7 @@ class BertForPreTraining(BertPreTrainedModel):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         sequence_output, pooled_output = outputs[:2]
@@ -879,6 +898,7 @@ class BertForMaskedLM(BertPreTrainedModel):
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         lm_labels=None,
+        output_attentions=None,
         **kwargs
     ):
         r"""
@@ -908,7 +928,7 @@ class BertForMaskedLM(BertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -946,6 +966,7 @@ class BertForMaskedLM(BertPreTrainedModel):
             inputs_embeds=inputs_embeds,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
         )
         sequence_output = outputs[0]
@@ -1019,6 +1040,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         next_sentence_label=None,
+        output_attentions=None,
     ):
         r"""
         next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1038,7 +1060,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -1068,6 +1090,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         pooled_output = outputs[1]
@@ -1109,6 +1132,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         labels=None,
+        output_attentions=None,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1128,7 +1152,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -1158,6 +1182,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         pooled_output = outputs[1]
@@ -1205,6 +1230,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         labels=None,
+        output_attentions=None,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1225,7 +1251,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -1265,6 +1291,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         pooled_output = outputs[1]
@@ -1309,6 +1336,7 @@ class BertForTokenClassification(BertPreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         labels=None,
+        output_attentions=None,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -1326,7 +1354,7 @@ class BertForTokenClassification(BertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -1356,6 +1384,7 @@ class BertForTokenClassification(BertPreTrainedModel):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         sequence_output = outputs[0]
@@ -1407,6 +1436,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         inputs_embeds=None,
         start_positions=None,
         end_positions=None,
+        output_attentions=None,
    ):
         r"""
         start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1431,7 +1461,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
             :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -1465,6 +1495,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
         )
         sequence_output = outputs[0]
......
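The Bert heads above now accept `output_attentions` per call and forward it to `self.bert`, so attention weights can be requested for a single forward pass without flipping `config.output_attentions`. A minimal usage sketch under that new signature; the checkpoint name and the tuple indexing are illustrative assumptions based on the tuple-return convention documented in the hunks above, not part of the diff:

import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
model.eval()

inputs = tokenizer.encode_plus("Attention weights on demand.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# Without labels the tuple is (logits, attentions): one tensor per layer of shape
# (batch_size, num_heads, sequence_length, sequence_length), as documented above.
logits, attentions = outputs[0], outputs[-1]
print(logits.shape, len(attentions), attentions[0].shape)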
...@@ -83,9 +83,8 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N ...@@ -83,9 +83,8 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
class MultiHeadAttention(torch.nn.Module): class MultiHeadAttention(torch.nn.Module):
def __init__(self, d_model_size, num_heads, output_attentions=False): def __init__(self, d_model_size, num_heads):
super().__init__() super().__init__()
self.output_attentions = output_attentions
self.num_heads = num_heads self.num_heads = num_heads
self.d_model_size = d_model_size self.d_model_size = d_model_size
...@@ -101,7 +100,18 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -101,7 +100,18 @@ class MultiHeadAttention(torch.nn.Module):
x = x.reshape(batch_size, -1, self.num_heads, self.depth) x = x.reshape(batch_size, -1, self.num_heads, self.depth)
return x.permute([0, 2, 1, 3]) return x.permute([0, 2, 1, 3])
def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False): def forward(
self,
v,
k,
q,
mask,
layer_past=None,
attention_mask=None,
head_mask=None,
use_cache=False,
output_attentions=False,
):
batch_size = q.shape[0] batch_size = q.shape[0]
q = self.Wq(q) q = self.Wq(q)
...@@ -128,7 +138,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -128,7 +138,7 @@ class MultiHeadAttention(torch.nn.Module):
output = self.dense(original_size_attention) output = self.dense(original_size_attention)
outputs = (output, present) outputs = (output, present)
if self.output_attentions: if output_attentions:
outputs = outputs + (attn,) outputs = outputs + (attn,)
return outputs return outputs
...@@ -138,10 +148,10 @@ def point_wise_feed_forward_network(d_model_size, dff): ...@@ -138,10 +148,10 @@ def point_wise_feed_forward_network(d_model_size, dff):
class EncoderLayer(torch.nn.Module): class EncoderLayer(torch.nn.Module):
def __init__(self, d_model_size, num_heads, dff, rate=0.1, output_attentions=False): def __init__(self, d_model_size, num_heads, dff, rate=0.1):
super().__init__() super().__init__()
self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, output_attentions) self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads)
self.ffn = point_wise_feed_forward_network(d_model_size, dff) self.ffn = point_wise_feed_forward_network(d_model_size, dff)
self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6) self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6)
...@@ -150,7 +160,9 @@ class EncoderLayer(torch.nn.Module): ...@@ -150,7 +160,9 @@ class EncoderLayer(torch.nn.Module):
self.dropout1 = torch.nn.Dropout(rate) self.dropout1 = torch.nn.Dropout(rate)
self.dropout2 = torch.nn.Dropout(rate) self.dropout2 = torch.nn.Dropout(rate)
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False): def forward(
self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
):
normed = self.layernorm1(x) normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention( attn_outputs = self.multi_head_attention(
normed, normed,
...@@ -161,6 +173,7 @@ class EncoderLayer(torch.nn.Module): ...@@ -161,6 +173,7 @@ class EncoderLayer(torch.nn.Module):
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
attn_output = attn_outputs[0] attn_output = attn_outputs[0]
attn_output = self.dropout1(attn_output) attn_output = self.dropout1(attn_output)
...@@ -264,7 +277,6 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -264,7 +277,6 @@ class CTRLModel(CTRLPreTrainedModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.d_model_size = config.n_embd self.d_model_size = config.n_embd
self.num_layers = config.n_layer self.num_layers = config.n_layer
...@@ -275,10 +287,7 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -275,10 +287,7 @@ class CTRLModel(CTRLPreTrainedModel):
self.dropout = nn.Dropout(config.embd_pdrop) self.dropout = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList( self.h = nn.ModuleList(
[ [EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop) for _ in range(config.n_layer)]
EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop, config.output_attentions)
for _ in range(config.n_layer)
]
) )
self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
...@@ -308,6 +317,7 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -308,6 +317,7 @@ class CTRLModel(CTRLPreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=True, use_cache=True,
output_attentions=None,
): ):
r""" r"""
Return: Return:
...@@ -322,7 +332,7 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -322,7 +332,7 @@ class CTRLModel(CTRLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -343,6 +353,7 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -343,6 +353,7 @@ class CTRLModel(CTRLPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
""" """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
...@@ -424,12 +435,13 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -424,12 +435,13 @@ class CTRLModel(CTRLPreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask[i], head_mask=head_mask[i],
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
hidden_states, present = outputs[:2] hidden_states, present = outputs[:2]
if use_cache is True: if use_cache is True:
presents = presents + (present,) presents = presents + (present,)
if self.output_attentions: if output_attentions:
all_attentions.append(outputs[2]) all_attentions.append(outputs[2])
hidden_states = self.layernorm(hidden_states) hidden_states = self.layernorm(hidden_states)
...@@ -442,7 +454,7 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -442,7 +454,7 @@ class CTRLModel(CTRLPreTrainedModel):
outputs = outputs + (presents,) outputs = outputs + (presents,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning # let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:] attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions) all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
...@@ -485,6 +497,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): ...@@ -485,6 +497,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
use_cache=True, use_cache=True,
output_attentions=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
...@@ -508,7 +521,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): ...@@ -508,7 +521,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -537,6 +550,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): ...@@ -537,6 +550,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
......
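The CTRL hunks above show the propagation pattern this PR applies across models: the top-level forward resolves the per-call flag against the config default, threads it through every layer, and only appends the collected attention maps to the output tuple when the flag is set. A self-contained toy sketch of that pattern; ToyLayer/ToyModel are invented names for illustration, not library classes:

import torch
from torch import nn

class ToyLayer(nn.Module):
    # Stand-in for an attention block: optionally returns its attention map.
    def forward(self, hidden, output_attentions=False):
        attn = torch.softmax(hidden @ hidden.transpose(-1, -2), dim=-1)
        hidden = attn @ hidden
        return (hidden, attn) if output_attentions else (hidden,)

class ToyModel(nn.Module):
    def __init__(self, n_layers=2, default_output_attentions=False):
        super().__init__()
        # Plays the role of config.output_attentions.
        self.default_output_attentions = default_output_attentions
        self.layers = nn.ModuleList([ToyLayer() for _ in range(n_layers)])

    def forward(self, hidden, output_attentions=None):
        # Per-call argument wins; fall back to the configured default otherwise.
        if output_attentions is None:
            output_attentions = self.default_output_attentions
        all_attentions = ()
        for layer in self.layers:
            layer_outputs = layer(hidden, output_attentions=output_attentions)
            hidden = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
        return (hidden, all_attentions) if output_attentions else (hidden,)

hidden = torch.randn(1, 4, 8)
print(len(ToyModel()(hidden)))                          # 1 element: attentions omitted
print(len(ToyModel()(hidden, output_attentions=True)))  # 2 elements: attentions appended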
...@@ -105,7 +105,6 @@ class MultiHeadSelfAttention(nn.Module): ...@@ -105,7 +105,6 @@ class MultiHeadSelfAttention(nn.Module):
self.n_heads = config.n_heads self.n_heads = config.n_heads
self.dim = config.dim self.dim = config.dim
self.dropout = nn.Dropout(p=config.attention_dropout) self.dropout = nn.Dropout(p=config.attention_dropout)
self.output_attentions = config.output_attentions
assert self.dim % self.n_heads == 0 assert self.dim % self.n_heads == 0
...@@ -131,7 +130,7 @@ class MultiHeadSelfAttention(nn.Module): ...@@ -131,7 +130,7 @@ class MultiHeadSelfAttention(nn.Module):
self.dim = attention_head_size * self.n_heads self.dim = attention_head_size * self.n_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, query, key, value, mask, head_mask=None): def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):
""" """
Parameters Parameters
---------- ----------
...@@ -184,7 +183,7 @@ class MultiHeadSelfAttention(nn.Module): ...@@ -184,7 +183,7 @@ class MultiHeadSelfAttention(nn.Module):
context = unshape(context) # (bs, q_length, dim) context = unshape(context) # (bs, q_length, dim)
context = self.out_lin(context) # (bs, q_length, dim) context = self.out_lin(context) # (bs, q_length, dim)
if self.output_attentions: if output_attentions:
return (context, weights) return (context, weights)
else: else:
return (context,) return (context,)
...@@ -213,8 +212,6 @@ class TransformerBlock(nn.Module): ...@@ -213,8 +212,6 @@ class TransformerBlock(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.output_attentions = config.output_attentions
assert config.dim % config.n_heads == 0 assert config.dim % config.n_heads == 0
self.attention = MultiHeadSelfAttention(config) self.attention = MultiHeadSelfAttention(config)
...@@ -223,7 +220,7 @@ class TransformerBlock(nn.Module): ...@@ -223,7 +220,7 @@ class TransformerBlock(nn.Module):
self.ffn = FFN(config) self.ffn = FFN(config)
self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
def forward(self, x, attn_mask=None, head_mask=None): def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):
""" """
Parameters Parameters
---------- ----------
...@@ -238,8 +235,10 @@ class TransformerBlock(nn.Module): ...@@ -238,8 +235,10 @@ class TransformerBlock(nn.Module):
The output of the transformer block contextualization. The output of the transformer block contextualization.
""" """
# Self-Attention # Self-Attention
sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask, head_mask=head_mask) sa_output = self.attention(
if self.output_attentions: query=x, key=x, value=x, mask=attn_mask, head_mask=head_mask, output_attentions=output_attentions,
)
if output_attentions:
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
assert type(sa_output) == tuple assert type(sa_output) == tuple
...@@ -251,7 +250,7 @@ class TransformerBlock(nn.Module): ...@@ -251,7 +250,7 @@ class TransformerBlock(nn.Module):
ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim) ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim)
output = (ffn_output,) output = (ffn_output,)
if self.output_attentions: if output_attentions:
output = (sa_weights,) + output output = (sa_weights,) + output
return output return output
...@@ -260,13 +259,12 @@ class Transformer(nn.Module): ...@@ -260,13 +259,12 @@ class Transformer(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.n_layers = config.n_layers self.n_layers = config.n_layers
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
layer = TransformerBlock(config) layer = TransformerBlock(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)]) self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])
def forward(self, x, attn_mask=None, head_mask=None): def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):
""" """
Parameters Parameters
---------- ----------
...@@ -294,10 +292,12 @@ class Transformer(nn.Module): ...@@ -294,10 +292,12 @@ class Transformer(nn.Module):
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,) all_hidden_states = all_hidden_states + (hidden_state,)
layer_outputs = layer_module(x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i]) layer_outputs = layer_module(
x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions
)
hidden_state = layer_outputs[-1] hidden_state = layer_outputs[-1]
if self.output_attentions: if output_attentions:
assert len(layer_outputs) == 2 assert len(layer_outputs) == 2
attentions = layer_outputs[0] attentions = layer_outputs[0]
all_attentions = all_attentions + (attentions,) all_attentions = all_attentions + (attentions,)
...@@ -311,7 +311,7 @@ class Transformer(nn.Module): ...@@ -311,7 +311,7 @@ class Transformer(nn.Module):
outputs = (hidden_state,) outputs = (hidden_state,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if output_attentions:
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions) return outputs # last-layer hidden state, (all hidden states), (all attentions)
...@@ -408,7 +408,9 @@ class DistilBertModel(DistilBertPreTrainedModel): ...@@ -408,7 +408,9 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.transformer.layer[layer].attention.prune_heads(heads) self.transformer.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None): def forward(
self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, output_attentions=None,
):
r""" r"""
Return: Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs: :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
...@@ -419,7 +421,7 @@ class DistilBertModel(DistilBertPreTrainedModel): ...@@ -419,7 +421,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -440,6 +442,8 @@ class DistilBertModel(DistilBertPreTrainedModel): ...@@ -440,6 +442,8 @@ class DistilBertModel(DistilBertPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
""" """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -459,7 +463,9 @@ class DistilBertModel(DistilBertPreTrainedModel): ...@@ -459,7 +463,9 @@ class DistilBertModel(DistilBertPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim) inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim)
tfmr_output = self.transformer(x=inputs_embeds, attn_mask=attention_mask, head_mask=head_mask) tfmr_output = self.transformer(
x=inputs_embeds, attn_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions,
)
hidden_state = tfmr_output[0] hidden_state = tfmr_output[0]
output = (hidden_state,) + tfmr_output[1:] output = (hidden_state,) + tfmr_output[1:]
...@@ -472,7 +478,6 @@ class DistilBertModel(DistilBertPreTrainedModel): ...@@ -472,7 +478,6 @@ class DistilBertModel(DistilBertPreTrainedModel):
class DistilBertForMaskedLM(DistilBertPreTrainedModel): class DistilBertForMaskedLM(DistilBertPreTrainedModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.distilbert = DistilBertModel(config) self.distilbert = DistilBertModel(config)
...@@ -488,7 +493,16 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel): ...@@ -488,7 +493,16 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
return self.vocab_projector return self.vocab_projector
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, **kwargs): def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
**kwargs
):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss. Labels for computing the masked language modeling loss.
...@@ -509,7 +523,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel): ...@@ -509,7 +523,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -537,7 +551,11 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel): ...@@ -537,7 +551,11 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
dlbrt_output = self.distilbert( dlbrt_output = self.distilbert(
input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
) )
hidden_states = dlbrt_output[0] # (bs, seq_length, dim) hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim) prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
...@@ -571,7 +589,15 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel): ...@@ -571,7 +589,15 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
self.init_weights() self.init_weights()
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None): def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss. Labels for computing the sequence classification/regression loss.
...@@ -590,7 +616,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel): ...@@ -590,7 +616,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -611,7 +637,11 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel): ...@@ -611,7 +637,11 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
""" """
distilbert_output = self.distilbert( distilbert_output = self.distilbert(
input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
) )
hidden_state = distilbert_output[0] # (bs, seq_len, dim) hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim) pooled_output = hidden_state[:, 0] # (bs, dim)
...@@ -658,6 +688,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel): ...@@ -658,6 +688,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
inputs_embeds=None, inputs_embeds=None,
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
output_attentions=None,
): ):
r""" r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
...@@ -682,7 +713,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel): ...@@ -682,7 +713,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -704,7 +735,11 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel): ...@@ -704,7 +735,11 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
""" """
distilbert_output = self.distilbert( distilbert_output = self.distilbert(
input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
) )
hidden_states = distilbert_output[0] # (bs, max_query_len, dim) hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
...@@ -752,7 +787,15 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel): ...@@ -752,7 +787,15 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
self.init_weights() self.init_weights()
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None): def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss. Labels for computing the token classification loss.
...@@ -769,7 +812,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel): ...@@ -769,7 +812,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -791,7 +834,11 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel): ...@@ -791,7 +834,11 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
""" """
outputs = self.distilbert( outputs = self.distilbert(
input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
......
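One DistilBert-specific detail in the hunks above: with the flag set, `TransformerBlock` returns `(sa_weights, ffn_output)`, attentions first, which is why `Transformer.forward` reads the hidden state from `layer_outputs[-1]` and the weights from `layer_outputs[0]`, while the model-level tuple keeps the usual ordering with attentions last. A usage sketch of the model-level API; the checkpoint name is an illustrative assumption:

import torch
from transformers import DistilBertTokenizer, DistilBertModel

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")
model.eval()

encoded = tokenizer.encode_plus("per-call attention outputs", return_tensors="pt")
with torch.no_grad():
    outputs = model(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        output_attentions=True,
    )

# (last_hidden_state, all_attentions) here; hidden states would slot in between
# when config.output_hidden_states is set.
last_hidden_state, all_attentions = outputs[0], outputs[-1]
print(last_hidden_state.shape, len(all_attentions))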
...@@ -269,6 +269,7 @@ class ElectraModel(ElectraPreTrainedModel): ...@@ -269,6 +269,7 @@ class ElectraModel(ElectraPreTrainedModel):
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None,
): ):
r""" r"""
Return: Return:
...@@ -280,7 +281,7 @@ class ElectraModel(ElectraPreTrainedModel): ...@@ -280,7 +281,7 @@ class ElectraModel(ElectraPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -301,6 +302,9 @@ class ElectraModel(ElectraPreTrainedModel): ...@@ -301,6 +302,9 @@ class ElectraModel(ElectraPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
""" """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -327,7 +331,12 @@ class ElectraModel(ElectraPreTrainedModel): ...@@ -327,7 +331,12 @@ class ElectraModel(ElectraPreTrainedModel):
if hasattr(self, "embeddings_project"): if hasattr(self, "embeddings_project"):
hidden_states = self.embeddings_project(hidden_states) hidden_states = self.embeddings_project(hidden_states)
hidden_states = self.encoder(hidden_states, attention_mask=extended_attention_mask, head_mask=head_mask) hidden_states = self.encoder(
hidden_states,
attention_mask=extended_attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
)
return hidden_states return hidden_states
...@@ -375,6 +384,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel): ...@@ -375,6 +384,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
...@@ -394,7 +404,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel): ...@@ -394,7 +404,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -417,7 +427,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel): ...@@ -417,7 +427,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
""" """
discriminator_hidden_states = self.electra( discriminator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions
) )
sequence_output = discriminator_hidden_states[0] sequence_output = discriminator_hidden_states[0]
...@@ -464,6 +474,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel): ...@@ -464,6 +474,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
): ):
r""" r"""
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`): labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
...@@ -483,7 +494,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel): ...@@ -483,7 +494,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -507,7 +518,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel): ...@@ -507,7 +518,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
""" """
discriminator_hidden_states = self.electra( discriminator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions,
) )
discriminator_sequence_output = discriminator_hidden_states[0] discriminator_sequence_output = discriminator_hidden_states[0]
...@@ -563,6 +574,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel): ...@@ -563,6 +574,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
**kwargs **kwargs
): ):
r""" r"""
...@@ -585,7 +597,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel): ...@@ -585,7 +597,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -615,7 +627,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel): ...@@ -615,7 +627,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
generator_hidden_states = self.electra( generator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions,
) )
generator_sequence_output = generator_hidden_states[0] generator_sequence_output = generator_hidden_states[0]
...@@ -661,6 +673,7 @@ class ElectraForTokenClassification(ElectraPreTrainedModel): ...@@ -661,6 +673,7 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
...@@ -678,7 +691,7 @@ class ElectraForTokenClassification(ElectraPreTrainedModel): ...@@ -678,7 +691,7 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -702,7 +715,7 @@ class ElectraForTokenClassification(ElectraPreTrainedModel): ...@@ -702,7 +715,7 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
""" """
discriminator_hidden_states = self.electra( discriminator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions,
) )
discriminator_sequence_output = discriminator_hidden_states[0] discriminator_sequence_output = discriminator_hidden_states[0]
......
...@@ -128,6 +128,7 @@ class FlaubertModel(XLMModel): ...@@ -128,6 +128,7 @@ class FlaubertModel(XLMModel):
cache=None, cache=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None,
): ):
r""" r"""
Return: Return:
...@@ -139,7 +140,7 @@ class FlaubertModel(XLMModel): ...@@ -139,7 +140,7 @@ class FlaubertModel(XLMModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -158,6 +159,8 @@ class FlaubertModel(XLMModel): ...@@ -158,6 +159,8 @@ class FlaubertModel(XLMModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
""" """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# removed: src_enc=None, src_len=None # removed: src_enc=None, src_len=None
if input_ids is not None: if input_ids is not None:
bs, slen = input_ids.size() bs, slen = input_ids.size()
...@@ -240,9 +243,11 @@ class FlaubertModel(XLMModel): ...@@ -240,9 +243,11 @@ class FlaubertModel(XLMModel):
# self attention # self attention
if not self.pre_norm: if not self.pre_norm:
attn_outputs = self.attentions[i](tensor, attn_mask, cache=cache, head_mask=head_mask[i]) attn_outputs = self.attentions[i](
tensor, attn_mask, cache=cache, head_mask=head_mask[i], output_attentions=output_attentions,
)
attn = attn_outputs[0] attn = attn_outputs[0]
if self.output_attentions: if output_attentions:
attentions = attentions + (attn_outputs[1],) attentions = attentions + (attn_outputs[1],)
attn = F.dropout(attn, p=self.dropout, training=self.training) attn = F.dropout(attn, p=self.dropout, training=self.training)
tensor = tensor + attn tensor = tensor + attn
...@@ -251,7 +256,7 @@ class FlaubertModel(XLMModel): ...@@ -251,7 +256,7 @@ class FlaubertModel(XLMModel):
tensor_normalized = self.layer_norm1[i](tensor) tensor_normalized = self.layer_norm1[i](tensor)
attn_outputs = self.attentions[i](tensor_normalized, attn_mask, cache=cache, head_mask=head_mask[i]) attn_outputs = self.attentions[i](tensor_normalized, attn_mask, cache=cache, head_mask=head_mask[i])
attn = attn_outputs[0] attn = attn_outputs[0]
if self.output_attentions: if output_attentions:
attentions = attentions + (attn_outputs[1],) attentions = attentions + (attn_outputs[1],)
attn = F.dropout(attn, p=self.dropout, training=self.training) attn = F.dropout(attn, p=self.dropout, training=self.training)
tensor = tensor + attn tensor = tensor + attn
...@@ -287,7 +292,7 @@ class FlaubertModel(XLMModel): ...@@ -287,7 +292,7 @@ class FlaubertModel(XLMModel):
outputs = (tensor,) outputs = (tensor,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (hidden_states,) outputs = outputs + (hidden_states,)
if self.output_attentions: if output_attentions:
outputs = outputs + (attentions,) outputs = outputs + (attentions,)
return outputs # outputs, (hidden_states), (attentions) return outputs # outputs, (hidden_states), (attentions)
......
...@@ -106,7 +106,6 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): ...@@ -106,7 +106,6 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False): def __init__(self, nx, n_ctx, config, scale=False):
super().__init__() super().__init__()
self.output_attentions = config.output_attentions
n_state = nx # in Attention: n_state=768 (nx=n_embd) n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem] # [switch nx => n_state from Block to Attention to keep identical to TF implem]
...@@ -142,7 +141,7 @@ class Attention(nn.Module): ...@@ -142,7 +141,7 @@ class Attention(nn.Module):
self.n_head = self.n_head - len(heads) self.n_head = self.n_head - len(heads)
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
def _attn(self, q, k, v, attention_mask=None, head_mask=None): def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
w = torch.matmul(q, k) w = torch.matmul(q, k)
if self.scale: if self.scale:
w = w / (float(v.size(-1)) ** 0.5) w = w / (float(v.size(-1)) ** 0.5)
...@@ -162,7 +161,7 @@ class Attention(nn.Module): ...@@ -162,7 +161,7 @@ class Attention(nn.Module):
w = w * head_mask w = w * head_mask
outputs = [torch.matmul(w, v)] outputs = [torch.matmul(w, v)]
if self.output_attentions: if output_attentions:
outputs.append(w) outputs.append(w)
return outputs return outputs
...@@ -179,7 +178,9 @@ class Attention(nn.Module): ...@@ -179,7 +178,9 @@ class Attention(nn.Module):
else: else:
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False): def forward(
self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
):
x = self.c_attn(x) x = self.c_attn(x)
query, key, value = x.split(self.split_size, dim=2) query, key, value = x.split(self.split_size, dim=2)
query = self.split_heads(query) query = self.split_heads(query)
...@@ -195,7 +196,7 @@ class Attention(nn.Module): ...@@ -195,7 +196,7 @@ class Attention(nn.Module):
else: else:
present = (None,) present = (None,)
attn_outputs = self._attn(query, key, value, attention_mask, head_mask) attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
a = attn_outputs[0] a = attn_outputs[0]
a = self.merge_heads(a) a = self.merge_heads(a)
...@@ -230,13 +231,16 @@ class Block(nn.Module): ...@@ -230,13 +231,16 @@ class Block(nn.Module):
self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config) self.mlp = MLP(4 * nx, config)
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False): def forward(
self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False,
):
output_attn = self.attn( output_attn = self.attn(
self.ln_1(x), self.ln_1(x),
layer_past=layer_past, layer_past=layer_past,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
a = output_attn[0] # output_attn: a, present, (attentions) a = output_attn[0] # output_attn: a, present, (attentions)
...@@ -342,7 +346,6 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -342,7 +346,6 @@ class GPT2Model(GPT2PreTrainedModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.wte = nn.Embedding(config.vocab_size, config.n_embd) self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd) self.wpe = nn.Embedding(config.n_positions, config.n_embd)
...@@ -376,6 +379,7 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -376,6 +379,7 @@ class GPT2Model(GPT2PreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=True, use_cache=True,
output_attentions=None,
): ):
r""" r"""
Return: Return:
...@@ -391,7 +395,7 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -391,7 +395,7 @@ class GPT2Model(GPT2PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -410,6 +414,7 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -410,6 +414,7 @@ class GPT2Model(GPT2PreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
""" """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
...@@ -488,13 +493,14 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -488,13 +493,14 @@ class GPT2Model(GPT2PreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask[i], head_mask=head_mask[i],
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
hidden_states, present = outputs[:2] hidden_states, present = outputs[:2]
if use_cache is True: if use_cache is True:
presents = presents + (present,) presents = presents + (present,)
if self.output_attentions: if output_attentions:
all_attentions.append(outputs[2]) all_attentions.append(outputs[2])
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
...@@ -509,7 +515,7 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -509,7 +515,7 @@ class GPT2Model(GPT2PreTrainedModel):
outputs = outputs + (presents,) outputs = outputs + (presents,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning # let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:] attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions) all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
...@@ -552,6 +558,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ...@@ -552,6 +558,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
use_cache=True, use_cache=True,
output_attentions=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
...@@ -575,7 +582,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ...@@ -575,7 +582,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -604,6 +611,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ...@@ -604,6 +611,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
...@@ -657,6 +665,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): ...@@ -657,6 +665,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
labels=None, labels=None,
mc_labels=None, mc_labels=None,
use_cache=True, use_cache=True,
output_attentions=None,
**kwargs **kwargs
): ):
r""" r"""
...@@ -694,7 +703,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): ...@@ -694,7 +703,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -742,6 +751,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): ...@@ -742,6 +751,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
......
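For reference, a minimal usage sketch of the new per-call flag with GPT-2; the checkpoint name and the output tuple positions are assumptions inferred from the docstrings above, not part of this diff:

    # Minimal sketch: request attention weights per call instead of via the config flag.
    from transformers import GPT2Model, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2Model.from_pretrained("gpt2")

    input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")
    outputs = model(input_ids, output_attentions=True)

    last_hidden_state = outputs[0]
    all_attentions = outputs[-1]  # one tensor per layer, (batch_size, num_heads, seq_len, seq_len)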
...@@ -87,7 +87,6 @@ class LongformerSelfAttention(nn.Module): ...@@ -87,7 +87,6 @@ class LongformerSelfAttention(nn.Module):
"The hidden size (%d) is not a multiple of the number of attention " "The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads) "heads (%d)" % (config.hidden_size, config.num_attention_heads)
) )
self.output_attentions = config.output_attentions
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.head_dim = int(config.hidden_size / config.num_attention_heads) self.head_dim = int(config.hidden_size / config.num_attention_heads)
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -242,6 +241,7 @@ class LongformerSelfAttention(nn.Module): ...@@ -242,6 +241,7 @@ class LongformerSelfAttention(nn.Module):
head_mask=None, head_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
output_attentions=False,
): ):
""" """
LongformerSelfAttention expects `len(hidden_states)` to be a multiple of `attention_window`. LongformerSelfAttention expects `len(hidden_states)` to be a multiple of `attention_window`.
...@@ -415,7 +415,7 @@ class LongformerSelfAttention(nn.Module): ...@@ -415,7 +415,7 @@ class LongformerSelfAttention(nn.Module):
) )
context_layer = attn.transpose(0, 1) context_layer = attn.transpose(0, 1)
if self.output_attentions: if output_attentions:
if extra_attention_mask is not None: if extra_attention_mask is not None:
# With global attention, return global attention probabilities only # With global attention, return global attention probabilities only
# batch_size x num_heads x max_num_global_attention_tokens x sequence_length # batch_size x num_heads x max_num_global_attention_tokens x sequence_length
...@@ -429,7 +429,7 @@ class LongformerSelfAttention(nn.Module): ...@@ -429,7 +429,7 @@ class LongformerSelfAttention(nn.Module):
# batch_size x num_heads x sequence_length x window_size # batch_size x num_heads x sequence_length x window_size
# which is the attention weights of every token attending to its neighbours # which is the attention weights of every token attending to its neighbours
attn_weights = attn_weights.permute(0, 2, 1, 3) attn_weights = attn_weights.permute(0, 2, 1, 3)
outputs = (context_layer, attn_weights) if self.output_attentions else (context_layer,) outputs = (context_layer, attn_weights) if output_attentions else (context_layer,)
return outputs return outputs
...@@ -584,6 +584,7 @@ class LongformerModel(RobertaModel): ...@@ -584,6 +584,7 @@ class LongformerModel(RobertaModel):
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None,
): ):
r""" r"""
...@@ -596,7 +597,7 @@ class LongformerModel(RobertaModel): ...@@ -596,7 +597,7 @@ class LongformerModel(RobertaModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -623,6 +624,8 @@ class LongformerModel(RobertaModel): ...@@ -623,6 +624,8 @@ class LongformerModel(RobertaModel):
sequence_output, pooled_output = model(input_ids, attention_mask=attention_mask) sequence_output, pooled_output = model(input_ids, attention_mask=attention_mask)
""" """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# padding # padding
attention_window = ( attention_window = (
self.config.attention_window self.config.attention_window
...@@ -662,6 +665,7 @@ class LongformerModel(RobertaModel): ...@@ -662,6 +665,7 @@ class LongformerModel(RobertaModel):
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
output_attentions=output_attentions,
) )
# undo padding # undo padding
...@@ -699,6 +703,7 @@ class LongformerForMaskedLM(BertPreTrainedModel): ...@@ -699,6 +703,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
position_ids=None, position_ids=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
**kwargs **kwargs
): ):
r""" r"""
...@@ -721,7 +726,7 @@ class LongformerForMaskedLM(BertPreTrainedModel): ...@@ -721,7 +726,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -759,6 +764,7 @@ class LongformerForMaskedLM(BertPreTrainedModel): ...@@ -759,6 +764,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output) prediction_scores = self.lm_head(sequence_output)
...@@ -799,6 +805,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel): ...@@ -799,6 +805,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
position_ids=None, position_ids=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
...@@ -818,7 +825,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel): ...@@ -818,7 +825,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -918,6 +925,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel): ...@@ -918,6 +925,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
inputs_embeds=None, inputs_embeds=None,
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
output_attentions=None,
): ):
r""" r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
...@@ -940,7 +948,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel): ...@@ -940,7 +948,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention Attention weights after the attention softmax, used to compute the weighted average in the self-attention
...@@ -983,6 +991,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel): ...@@ -983,6 +991,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
...@@ -1042,6 +1051,7 @@ class LongformerForTokenClassification(BertPreTrainedModel): ...@@ -1042,6 +1051,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
position_ids=None, position_ids=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
...@@ -1059,7 +1069,7 @@ class LongformerForTokenClassification(BertPreTrainedModel): ...@@ -1059,7 +1069,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -1141,6 +1151,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel): ...@@ -1141,6 +1151,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
labels=None, labels=None,
position_ids=None, position_ids=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
...@@ -1161,7 +1172,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel): ...@@ -1161,7 +1172,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......
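The same keyword applies to the Longformer classes above; a minimal sketch, assuming the `allenai/longformer-base-4096` checkpoint and the convention that a mask value of 1 marks local attention:

    import torch
    from transformers import LongformerModel, LongformerTokenizer

    tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
    model = LongformerModel.from_pretrained("allenai/longformer-base-4096")

    input_ids = tokenizer.encode("Hello world!", return_tensors="pt")
    attention_mask = torch.ones_like(input_ids)  # local attention on every token

    outputs = model(input_ids, attention_mask=attention_mask, output_attentions=True)
    attentions = outputs[-1]  # per-layer attention probabilities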
...@@ -165,7 +165,7 @@ class MMBTModel(nn.Module, ModuleUtilsMixin): ...@@ -165,7 +165,7 @@ class MMBTModel(nn.Module, ModuleUtilsMixin):
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``) **attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
...@@ -295,7 +295,7 @@ class MMBTForClassification(nn.Module): ...@@ -295,7 +295,7 @@ class MMBTForClassification(nn.Module):
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``) **attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
......
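Every model touched by this change resolves the flag the same way: the per-call argument wins and the config value is only the fallback default. A standalone sketch of that rule (the helper name and the `SimpleNamespace` stand-in are illustrative only):

    from types import SimpleNamespace

    def resolve_output_attentions(output_attentions, config):
        # The per-call argument takes precedence; config.output_attentions is only the default.
        return output_attentions if output_attentions is not None else config.output_attentions

    config = SimpleNamespace(output_attentions=False)
    assert resolve_output_attentions(None, config) is False  # falls back to the config default
    assert resolve_output_attentions(True, config) is True   # explicit per-call request wins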
...@@ -137,8 +137,6 @@ class Attention(nn.Module): ...@@ -137,8 +137,6 @@ class Attention(nn.Module):
self.split_size = n_state self.split_size = n_state
self.scale = scale self.scale = scale
self.output_attentions = config.output_attentions
self.c_attn = Conv1D(n_state * 3, nx) self.c_attn = Conv1D(n_state * 3, nx)
self.c_proj = Conv1D(n_state, nx) self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop) self.attn_dropout = nn.Dropout(config.attn_pdrop)
...@@ -160,7 +158,7 @@ class Attention(nn.Module): ...@@ -160,7 +158,7 @@ class Attention(nn.Module):
self.n_head = self.n_head - len(heads) self.n_head = self.n_head - len(heads)
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
def _attn(self, q, k, v, attention_mask=None, head_mask=None): def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
w = torch.matmul(q, k) w = torch.matmul(q, k)
if self.scale: if self.scale:
w = w / math.sqrt(v.size(-1)) w = w / math.sqrt(v.size(-1))
...@@ -181,7 +179,7 @@ class Attention(nn.Module): ...@@ -181,7 +179,7 @@ class Attention(nn.Module):
w = w * head_mask w = w * head_mask
outputs = [torch.matmul(w, v)] outputs = [torch.matmul(w, v)]
if self.output_attentions: if output_attentions:
outputs.append(w) outputs.append(w)
return outputs return outputs
...@@ -198,14 +196,14 @@ class Attention(nn.Module): ...@@ -198,14 +196,14 @@ class Attention(nn.Module):
else: else:
return x.permute(0, 2, 1, 3) return x.permute(0, 2, 1, 3)
def forward(self, x, attention_mask=None, head_mask=None): def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
x = self.c_attn(x) x = self.c_attn(x)
query, key, value = x.split(self.split_size, dim=2) query, key, value = x.split(self.split_size, dim=2)
query = self.split_heads(query) query = self.split_heads(query)
key = self.split_heads(key, k=True) key = self.split_heads(key, k=True)
value = self.split_heads(value) value = self.split_heads(value)
attn_outputs = self._attn(query, key, value, attention_mask, head_mask) attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
a = attn_outputs[0] a = attn_outputs[0]
a = self.merge_heads(a) a = self.merge_heads(a)
...@@ -240,8 +238,10 @@ class Block(nn.Module): ...@@ -240,8 +238,10 @@ class Block(nn.Module):
self.mlp = MLP(4 * nx, config) self.mlp = MLP(4 * nx, config)
self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
def forward(self, x, attention_mask=None, head_mask=None): def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
attn_outputs = self.attn(x, attention_mask=attention_mask, head_mask=head_mask) attn_outputs = self.attn(
x, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions,
)
a = attn_outputs[0] a = attn_outputs[0]
n = self.ln_1(x + a) n = self.ln_1(x + a)
...@@ -322,6 +322,8 @@ OPENAI_GPT_INPUTS_DOCSTRING = r""" ...@@ -322,6 +322,8 @@ OPENAI_GPT_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix. than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the model should return the attention weights of all attention layers.
""" """
...@@ -332,7 +334,6 @@ OPENAI_GPT_INPUTS_DOCSTRING = r""" ...@@ -332,7 +334,6 @@ OPENAI_GPT_INPUTS_DOCSTRING = r"""
class OpenAIGPTModel(OpenAIGPTPreTrainedModel): class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd) self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
...@@ -364,6 +365,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -364,6 +365,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None,
): ):
r""" r"""
Return: Return:
...@@ -375,7 +377,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -375,7 +377,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -394,6 +396,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -394,6 +396,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
""" """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -449,9 +453,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -449,9 +453,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),) all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = block(hidden_states, attention_mask, head_mask[i]) outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions)
hidden_states = outputs[0] hidden_states = outputs[0]
if self.output_attentions: if output_attentions:
all_attentions = all_attentions + (outputs[1],) all_attentions = all_attentions + (outputs[1],)
# Add last layer # Add last layer
...@@ -461,7 +465,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -461,7 +465,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
outputs = (hidden_states.view(*output_shape),) outputs = (hidden_states.view(*output_shape),)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if output_attentions:
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
return outputs # last hidden state, (all hidden states), (all attentions) return outputs # last hidden state, (all hidden states), (all attentions)
...@@ -492,6 +496,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -492,6 +496,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
...@@ -516,7 +521,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -516,7 +521,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -542,6 +547,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -542,6 +547,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
...@@ -593,6 +599,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -593,6 +599,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
mc_token_ids=None, mc_token_ids=None,
labels=None, labels=None,
mc_labels=None, mc_labels=None,
output_attentions=None,
**kwargs **kwargs
): ):
r""" r"""
...@@ -631,7 +638,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -631,7 +638,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -671,6 +678,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -671,6 +678,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
......
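A minimal sketch mirroring the OpenAIGPTModel docstring above; the checkpoint name is an assumption:

    from transformers import OpenAIGPTModel, OpenAIGPTTokenizer

    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
    model = OpenAIGPTModel.from_pretrained("openai-gpt")

    input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")
    outputs = model(input_ids, output_attentions=True)

    last_hidden_state = outputs[0]
    all_attentions = outputs[-1]  # one (batch_size, num_heads, seq_len, seq_len) tensor per layer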
...@@ -318,7 +318,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -318,7 +318,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
attention_mask=None, attention_mask=None,
head_mask=None, head_mask=None,
num_hashes=None, num_hashes=None,
do_output_attentions=False, output_attentions=False,
buckets=None, buckets=None,
**kwargs **kwargs
): ):
...@@ -444,7 +444,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -444,7 +444,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size)
if do_output_attentions is False: if output_attentions is False:
attention_probs = () attention_probs = ()
return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets) return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets)
...@@ -801,7 +801,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -801,7 +801,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
self.register_buffer("mask_value_float16", torch.tensor(-1e4)) self.register_buffer("mask_value_float16", torch.tensor(-1e4))
self.register_buffer("mask_value_float32", torch.tensor(-1e9)) self.register_buffer("mask_value_float32", torch.tensor(-1e9))
def forward(self, hidden_states, attention_mask=None, head_mask=None, do_output_attentions=False, **kwargs): def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, **kwargs):
sequence_length = hidden_states.shape[1] sequence_length = hidden_states.shape[1]
batch_size = hidden_states.shape[0] batch_size = hidden_states.shape[0]
...@@ -921,7 +921,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -921,7 +921,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size)
if do_output_attentions is False: if output_attentions is False:
attention_probs = () attention_probs = ()
return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs) return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs)
...@@ -1001,7 +1001,7 @@ class ReformerAttention(nn.Module): ...@@ -1001,7 +1001,7 @@ class ReformerAttention(nn.Module):
attention_mask=None, attention_mask=None,
head_mask=None, head_mask=None,
num_hashes=None, num_hashes=None,
do_output_attentions=False, output_attentions=False,
buckets=None, buckets=None,
): ):
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
...@@ -1012,7 +1012,7 @@ class ReformerAttention(nn.Module): ...@@ -1012,7 +1012,7 @@ class ReformerAttention(nn.Module):
head_mask=head_mask, head_mask=head_mask,
attention_mask=attention_mask, attention_mask=attention_mask,
num_hashes=num_hashes, num_hashes=num_hashes,
do_output_attentions=do_output_attentions, output_attentions=output_attentions,
buckets=buckets, buckets=buckets,
) )
attention_output = self.output(self_attention_outputs.hidden_states) attention_output = self.output(self_attention_outputs.hidden_states)
...@@ -1139,7 +1139,7 @@ class ReformerLayer(nn.Module): ...@@ -1139,7 +1139,7 @@ class ReformerLayer(nn.Module):
attention_mask=None, attention_mask=None,
head_mask=None, head_mask=None,
num_hashes=None, num_hashes=None,
do_output_attentions=False, output_attentions=False,
): ):
with torch.no_grad(): with torch.no_grad():
# every forward pass we sample a different seed # every forward pass we sample a different seed
...@@ -1151,7 +1151,7 @@ class ReformerLayer(nn.Module): ...@@ -1151,7 +1151,7 @@ class ReformerLayer(nn.Module):
head_mask=head_mask, head_mask=head_mask,
attention_mask=attention_mask, attention_mask=attention_mask,
num_hashes=num_hashes, num_hashes=num_hashes,
do_output_attentions=do_output_attentions, output_attentions=output_attentions,
) )
attn_output = attn_outputs.hidden_states attn_output = attn_outputs.hidden_states
...@@ -1257,7 +1257,7 @@ class _ReversibleFunction(Function): ...@@ -1257,7 +1257,7 @@ class _ReversibleFunction(Function):
all_hidden_states, all_hidden_states,
all_attentions, all_attentions,
do_output_hidden_states, do_output_hidden_states,
do_output_attentions, output_attentions,
): ):
all_buckets = () all_buckets = ()
...@@ -1274,13 +1274,13 @@ class _ReversibleFunction(Function): ...@@ -1274,13 +1274,13 @@ class _ReversibleFunction(Function):
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=layer_head_mask, head_mask=layer_head_mask,
num_hashes=num_hashes, num_hashes=num_hashes,
do_output_attentions=do_output_attentions, output_attentions=output_attentions,
) )
attn_output = layer_outputs.attn_output attn_output = layer_outputs.attn_output
hidden_states = layer_outputs.hidden_states hidden_states = layer_outputs.hidden_states
all_buckets = all_buckets + (layer_outputs.buckets,) all_buckets = all_buckets + (layer_outputs.buckets,)
if do_output_attentions: if output_attentions:
all_attentions.append(layer_outputs.attention_probs) all_attentions.append(layer_outputs.attention_probs)
# Add last layer # Add last layer
...@@ -1361,7 +1361,7 @@ class ReformerEncoder(nn.Module): ...@@ -1361,7 +1361,7 @@ class ReformerEncoder(nn.Module):
head_mask=None, head_mask=None,
num_hashes=None, num_hashes=None,
do_output_hidden_states=False, do_output_hidden_states=False,
do_output_attentions=False, output_attentions=False,
): ):
# hidden_states and attention lists to be filled if wished # hidden_states and attention lists to be filled if wished
all_hidden_states = [] all_hidden_states = []
...@@ -1378,7 +1378,7 @@ class ReformerEncoder(nn.Module): ...@@ -1378,7 +1378,7 @@ class ReformerEncoder(nn.Module):
all_hidden_states, all_hidden_states,
all_attentions, all_attentions,
do_output_hidden_states, do_output_hidden_states,
do_output_attentions, output_attentions,
) )
# Apply layer norm to concatenated hidden states # Apply layer norm to concatenated hidden states
...@@ -1549,7 +1549,7 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -1549,7 +1549,7 @@ class ReformerModel(ReformerPreTrainedModel):
inputs_embeds=None, inputs_embeds=None,
num_hashes=None, num_hashes=None,
do_output_hidden_states=False, do_output_hidden_states=False,
do_output_attentions=False, output_attentions=None,
): ):
r""" r"""
Return: Return:
...@@ -1561,7 +1561,7 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -1561,7 +1561,7 @@ class ReformerModel(ReformerPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``do_output_attentions=True``): all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -1582,8 +1582,7 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -1582,8 +1582,7 @@ class ReformerModel(ReformerPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
""" """
# TODO(PVP): delete when PR to change output_attentions is made output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
do_output_attentions = self.config.output_attentions
do_output_hidden_states = self.config.output_hidden_states do_output_hidden_states = self.config.output_hidden_states
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
...@@ -1643,7 +1642,7 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -1643,7 +1642,7 @@ class ReformerModel(ReformerPreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
num_hashes=num_hashes, num_hashes=num_hashes,
do_output_hidden_states=do_output_hidden_states, do_output_hidden_states=do_output_hidden_states,
do_output_attentions=do_output_attentions, output_attentions=output_attentions,
) )
sequence_output = encoder_outputs.hidden_states sequence_output = encoder_outputs.hidden_states
...@@ -1655,7 +1654,7 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -1655,7 +1654,7 @@ class ReformerModel(ReformerPreTrainedModel):
# TODO(PVP): Replace by named tuple after namedtuples are introduced in the library. # TODO(PVP): Replace by named tuple after namedtuples are introduced in the library.
if do_output_hidden_states is True: if do_output_hidden_states is True:
outputs = outputs + (encoder_outputs.all_hidden_states,) outputs = outputs + (encoder_outputs.all_hidden_states,)
if do_output_attentions is True: if output_attentions is True:
outputs = outputs + (encoder_outputs.all_attentions,) outputs = outputs + (encoder_outputs.all_attentions,)
return outputs return outputs
...@@ -1744,7 +1743,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): ...@@ -1744,7 +1743,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
num_hashes=None, num_hashes=None,
labels=None, labels=None,
do_output_hidden_states=False, do_output_hidden_states=False,
do_output_attentions=False, output_attentions=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
...@@ -1764,7 +1763,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): ...@@ -1764,7 +1763,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``do_output_attentions=True``): all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -1793,7 +1792,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): ...@@ -1793,7 +1792,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
num_hashes=num_hashes, num_hashes=num_hashes,
do_output_hidden_states=do_output_hidden_states, do_output_hidden_states=do_output_hidden_states,
do_output_attentions=do_output_attentions, output_attentions=output_attentions,
) )
sequence_output = reformer_outputs[0] sequence_output = reformer_outputs[0]
......
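After the rename from `do_output_attentions`, Reformer takes the same keyword as the other models. A minimal sketch, assuming the `google/reformer-crime-and-punishment` checkpoint; Reformer may additionally pad the input internally to a multiple of its chunk length:

    from transformers import ReformerModelWithLMHead, ReformerTokenizer

    tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
    model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment")

    input_ids = tokenizer.encode("A few words to encode", return_tensors="pt")
    outputs = model(input_ids, output_attentions=True)

    lm_logits = outputs[0]
    all_attentions = outputs[-1]  # per-layer attention probabilities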
...@@ -185,6 +185,7 @@ class RobertaForMaskedLM(BertPreTrainedModel): ...@@ -185,6 +185,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
**kwargs **kwargs
): ):
r""" r"""
...@@ -207,7 +208,7 @@ class RobertaForMaskedLM(BertPreTrainedModel): ...@@ -207,7 +208,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -241,6 +242,7 @@ class RobertaForMaskedLM(BertPreTrainedModel): ...@@ -241,6 +242,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output) prediction_scores = self.lm_head(sequence_output)
...@@ -306,6 +308,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel): ...@@ -306,6 +308,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
...@@ -325,7 +328,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel): ...@@ -325,7 +328,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -352,6 +355,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel): ...@@ -352,6 +355,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
...@@ -398,6 +402,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel): ...@@ -398,6 +402,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
...@@ -418,7 +423,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel): ...@@ -418,7 +423,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -451,6 +456,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel): ...@@ -451,6 +456,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
token_type_ids=flat_token_type_ids, token_type_ids=flat_token_type_ids,
attention_mask=flat_attention_mask, attention_mask=flat_attention_mask,
head_mask=head_mask, head_mask=head_mask,
output_attentions=output_attentions,
) )
pooled_output = outputs[1] pooled_output = outputs[1]
...@@ -497,6 +503,7 @@ class RobertaForTokenClassification(BertPreTrainedModel): ...@@ -497,6 +503,7 @@ class RobertaForTokenClassification(BertPreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
...@@ -514,7 +521,7 @@ class RobertaForTokenClassification(BertPreTrainedModel): ...@@ -514,7 +521,7 @@ class RobertaForTokenClassification(BertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -542,6 +549,7 @@ class RobertaForTokenClassification(BertPreTrainedModel): ...@@ -542,6 +549,7 @@ class RobertaForTokenClassification(BertPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
...@@ -616,6 +624,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel): ...@@ -616,6 +624,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
inputs_embeds=None, inputs_embeds=None,
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
output_attentions=None,
): ):
r""" r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
...@@ -640,7 +649,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel): ...@@ -640,7 +649,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -674,6 +683,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel): ...@@ -674,6 +683,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
......
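A minimal sketch for the RoBERTa heads above; the checkpoint name is an assumption, and the loss is only returned because labels are passed:

    import torch
    from transformers import RobertaForSequenceClassification, RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    model = RobertaForSequenceClassification.from_pretrained("roberta-base")

    input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")
    labels = torch.tensor([1])  # shape (batch_size,)

    outputs = model(input_ids, labels=labels, output_attentions=True)
    loss, logits = outputs[:2]
    attentions = outputs[-1]  # one (batch_size, num_heads, seq_len, seq_len) tensor per layer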
...@@ -195,7 +195,6 @@ class T5Attention(nn.Module): ...@@ -195,7 +195,6 @@ class T5Attention(nn.Module):
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias self.has_relative_attention_bias = has_relative_attention_bias
self.output_attentions = config.output_attentions
self.relative_attention_num_buckets = config.relative_attention_num_buckets self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.d_model = config.d_model self.d_model = config.d_model
self.d_kv = config.d_kv self.d_kv = config.d_kv
...@@ -300,6 +299,7 @@ class T5Attention(nn.Module): ...@@ -300,6 +299,7 @@ class T5Attention(nn.Module):
head_mask=None, head_mask=None,
query_length=None, query_length=None,
use_cache=False, use_cache=False,
output_attentions=False,
): ):
""" """
Self-attention (if kv is None) or attention over source sentence (provided by kv). Self-attention (if kv is None) or attention over source sentence (provided by kv).
...@@ -386,7 +386,7 @@ class T5Attention(nn.Module): ...@@ -386,7 +386,7 @@ class T5Attention(nn.Module):
outputs = (context,) + present_key_value_state outputs = (context,) + present_key_value_state
if self.output_attentions: if output_attentions:
outputs = outputs + (weights,) outputs = outputs + (weights,)
if self.has_relative_attention_bias: if self.has_relative_attention_bias:
outputs = outputs + (position_bias,) outputs = outputs + (position_bias,)
...@@ -408,6 +408,7 @@ class T5LayerSelfAttention(nn.Module): ...@@ -408,6 +408,7 @@ class T5LayerSelfAttention(nn.Module):
head_mask=None, head_mask=None,
past_key_value_state=None, past_key_value_state=None,
use_cache=False, use_cache=False,
output_attentions=False,
): ):
norm_x = self.layer_norm(hidden_states) norm_x = self.layer_norm(hidden_states)
attention_output = self.SelfAttention( attention_output = self.SelfAttention(
...@@ -417,6 +418,7 @@ class T5LayerSelfAttention(nn.Module): ...@@ -417,6 +418,7 @@ class T5LayerSelfAttention(nn.Module):
head_mask=head_mask, head_mask=head_mask,
past_key_value_state=past_key_value_state, past_key_value_state=past_key_value_state,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
y = attention_output[0] y = attention_output[0]
layer_output = hidden_states + self.dropout(y) layer_output = hidden_states + self.dropout(y)
...@@ -441,6 +443,7 @@ class T5LayerCrossAttention(nn.Module): ...@@ -441,6 +443,7 @@ class T5LayerCrossAttention(nn.Module):
past_key_value_state=None, past_key_value_state=None,
use_cache=False, use_cache=False,
query_length=None, query_length=None,
output_attentions=False,
): ):
norm_x = self.layer_norm(hidden_states) norm_x = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention( attention_output = self.EncDecAttention(
...@@ -452,6 +455,7 @@ class T5LayerCrossAttention(nn.Module): ...@@ -452,6 +455,7 @@ class T5LayerCrossAttention(nn.Module):
past_key_value_state=past_key_value_state, past_key_value_state=past_key_value_state,
use_cache=use_cache, use_cache=use_cache,
query_length=query_length, query_length=query_length,
output_attentions=output_attentions,
) )
y = attention_output[0] y = attention_output[0]
layer_output = hidden_states + self.dropout(y) layer_output = hidden_states + self.dropout(y)
...@@ -481,6 +485,7 @@ class T5Block(nn.Module): ...@@ -481,6 +485,7 @@ class T5Block(nn.Module):
head_mask=None, head_mask=None,
past_key_value_state=None, past_key_value_state=None,
use_cache=False, use_cache=False,
output_attentions=False,
): ):
if past_key_value_state is not None: if past_key_value_state is not None:
...@@ -506,6 +511,7 @@ class T5Block(nn.Module): ...@@ -506,6 +511,7 @@ class T5Block(nn.Module):
head_mask=head_mask, head_mask=head_mask,
past_key_value_state=self_attn_past_key_value_state, past_key_value_state=self_attn_past_key_value_state,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
hidden_states, present_key_value_state = self_attention_outputs[:2] hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
...@@ -527,6 +533,7 @@ class T5Block(nn.Module): ...@@ -527,6 +533,7 @@ class T5Block(nn.Module):
past_key_value_state=cross_attn_past_key_value_state, past_key_value_state=cross_attn_past_key_value_state,
query_length=query_length, query_length=query_length,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
hidden_states = cross_attention_outputs[0] hidden_states = cross_attention_outputs[0]
# Combine self attn and cross attn key value states # Combine self attn and cross attn key value states
...@@ -622,7 +629,6 @@ class T5PreTrainedModel(PreTrainedModel): ...@@ -622,7 +629,6 @@ class T5PreTrainedModel(PreTrainedModel):
class T5Stack(T5PreTrainedModel): class T5Stack(T5PreTrainedModel):
def __init__(self, config, embed_tokens=None): def __init__(self, config, embed_tokens=None):
super().__init__(config) super().__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
...@@ -655,8 +661,11 @@ class T5Stack(T5PreTrainedModel): ...@@ -655,8 +661,11 @@ class T5Stack(T5PreTrainedModel):
head_mask=None, head_mask=None,
past_key_value_states=None, past_key_value_states=None,
use_cache=False, use_cache=False,
output_attentions=None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -730,6 +739,7 @@ class T5Stack(T5PreTrainedModel): ...@@ -730,6 +739,7 @@ class T5Stack(T5PreTrainedModel):
head_mask=head_mask[i], head_mask=head_mask[i],
past_key_value_state=past_key_value_state, past_key_value_state=past_key_value_state,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
# layer_outputs is a tuple with: # layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
...@@ -738,13 +748,13 @@ class T5Stack(T5PreTrainedModel): ...@@ -738,13 +748,13 @@ class T5Stack(T5PreTrainedModel):
if i == 0: if i == 0:
# We share the position biases between the layers - the first layer store them # We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) # layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[3 if self.output_attentions else 2] position_bias = layer_outputs[3 if output_attentions else 2]
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[5 if self.output_attentions else 3] encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]
# append next layer key value states # append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,) present_key_value_states = present_key_value_states + (present_key_value_state,)
if self.output_attentions: if output_attentions:
all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
...@@ -760,7 +770,7 @@ class T5Stack(T5PreTrainedModel): ...@@ -760,7 +770,7 @@ class T5Stack(T5PreTrainedModel):
outputs = outputs + (present_key_value_states,) outputs = outputs + (present_key_value_states,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if output_attentions:
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (presents,) (all hidden states), (all attentions) return outputs # last-layer hidden state, (presents,) (all hidden states), (all attentions)
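For reference, the pattern this hunk establishes is repeated across every model file in the PR: a per-call output_attentions argument that falls back to the config value when left as None, and an output tuple that only grows when the flag is set (which is also why downstream indexing such as layer_outputs[3 if output_attentions else 2] has to switch on the same flag). A minimal, self-contained sketch of that pattern, using a placeholder class rather than the real T5Stack:

# Minimal sketch of the per-call override pattern (placeholder names, not the real T5Stack).
class StackSketch:
    def __init__(self, config_output_attentions=False):
        # stands in for config.output_attentions
        self.config_output_attentions = config_output_attentions

    def forward(self, hidden_states, output_attentions=None):
        # per-call value wins; None falls back to the config default
        output_attentions = (
            output_attentions if output_attentions is not None else self.config_output_attentions
        )
        all_attentions = ()
        if output_attentions:
            all_attentions = all_attentions + ("attn-weights-layer-0",)
        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (all_attentions,)  # tuple only grows when requested
        return outputs

stack = StackSketch(config_output_attentions=False)
print(len(stack.forward("h")))                          # 1 -> just the hidden states
print(len(stack.forward("h", output_attentions=True)))  # 2 -> hidden states + attentions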
...@@ -887,6 +897,7 @@ class T5Model(T5PreTrainedModel): ...@@ -887,6 +897,7 @@ class T5Model(T5PreTrainedModel):
inputs_embeds=None, inputs_embeds=None,
decoder_inputs_embeds=None, decoder_inputs_embeds=None,
head_mask=None, head_mask=None,
output_attentions=None,
): ):
r""" r"""
Return: Return:
...@@ -903,7 +914,7 @@ class T5Model(T5PreTrainedModel): ...@@ -903,7 +914,7 @@ class T5Model(T5PreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -925,7 +936,11 @@ class T5Model(T5PreTrainedModel): ...@@ -925,7 +936,11 @@ class T5Model(T5PreTrainedModel):
# Encode if needed (training, first prediction pass) # Encode if needed (training, first prediction pass)
if encoder_outputs is None: if encoder_outputs is None:
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
) )
hidden_states = encoder_outputs[0] hidden_states = encoder_outputs[0]
...@@ -948,6 +963,7 @@ class T5Model(T5PreTrainedModel): ...@@ -948,6 +963,7 @@ class T5Model(T5PreTrainedModel):
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
if use_cache is True: if use_cache is True:
...@@ -1007,6 +1023,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1007,6 +1023,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
inputs_embeds=None, inputs_embeds=None,
decoder_inputs_embeds=None, decoder_inputs_embeds=None,
head_mask=None, head_mask=None,
output_attentions=None,
**kwargs **kwargs
): ):
r""" r"""
...@@ -1033,7 +1050,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1033,7 +1050,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention.
...@@ -1066,7 +1083,11 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1066,7 +1083,11 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
if encoder_outputs is None: if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed # Convert encoder inputs in embeddings if needed
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
) )
hidden_states = encoder_outputs[0] hidden_states = encoder_outputs[0]
...@@ -1094,6 +1115,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1094,6 +1115,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
) )
# insert decoder past at right place # insert decoder past at right place
......
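With the T5 changes above, attentions can be requested per call instead of by reloading the model with a modified config. A hedged usage sketch, assuming this PR's transformers checkout, sentencepiece, and the public t5-small checkpoint; with default config flags the combined output tuple ends with the encoder's per-layer attentions, but the exact layout depends on the other output flags:

import torch
from transformers import T5Model, T5Tokenizer  # T5Tokenizer needs sentencepiece installed

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small")
input_ids = tokenizer.encode("Studies have shown that owning a dog is good for you", return_tensors="pt")

with torch.no_grad():
    # no need to set config.output_attentions=True anymore; the flag is per call
    outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, output_attentions=True)

# one (batch, num_heads, seq_len, seq_len) tensor per layer when the flag is set
encoder_attentions = outputs[-1]
print(len(encoder_attentions), encoder_attentions[0].shape)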
...@@ -29,6 +29,7 @@ from .modeling_tf_utils import ( ...@@ -29,6 +29,7 @@ from .modeling_tf_utils import (
TFQuestionAnsweringLoss, TFQuestionAnsweringLoss,
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
TFTokenClassificationLoss, TFTokenClassificationLoss,
cast_bool_to_primitive,
get_initializer, get_initializer,
keras_serializable, keras_serializable,
shape_list, shape_list,
...@@ -158,7 +159,6 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer): ...@@ -158,7 +159,6 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
"The hidden size (%d) is not a multiple of the number of attention " "The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads) "heads (%d)" % (config.hidden_size, config.num_attention_heads)
) )
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
assert config.hidden_size % config.num_attention_heads == 0 assert config.hidden_size % config.num_attention_heads == 0
...@@ -182,7 +182,7 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer): ...@@ -182,7 +182,7 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
return tf.transpose(x, perm=[0, 2, 1, 3]) return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs, training=False): def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs hidden_states, attention_mask, head_mask, output_attentions = inputs
batch_size = shape_list(hidden_states)[0] batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
...@@ -222,7 +222,9 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer): ...@@ -222,7 +222,9 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
context_layer, (batch_size, -1, self.all_head_size) context_layer, (batch_size, -1, self.all_head_size)
) # (batch_size, seq_len_q, all_head_size) ) # (batch_size, seq_len_q, all_head_size)
outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) outputs = (
(context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
)
return outputs return outputs
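The newly imported cast_bool_to_primitive helper exists because a Python bool routed through a Keras layer's input list can arrive as a tf.Tensor, and the code above compares the cast result with `is True` so graph-mode behaviour stays well defined. The helper itself is not shown in this diff; the sketch below is a hypothetical reconstruction (note the different name) based on the inline tf.is_tensor / .numpy() logic this PR removes from TFMultiHeadAttention further down:

import tensorflow as tf

def cast_bool_to_primitive_sketch(bool_variable, default_tensor_to_true=False):
    # Hypothetical reconstruction, not the library function: if the flag arrives as a
    # tensor, try to read its concrete value; otherwise assume the supplied default.
    if tf.is_tensor(bool_variable):
        if hasattr(bool_variable, "numpy"):
            return bool(bool_variable.numpy())
        return default_tensor_to_true
    return bool_variable

print(cast_bool_to_primitive_sketch(True))               # True (plain bool passes through)
print(cast_bool_to_primitive_sketch(tf.constant(True)))  # True (eager tensor is readable)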
...@@ -259,7 +261,7 @@ class TFAlbertAttention(TFBertSelfAttention): ...@@ -259,7 +261,7 @@ class TFAlbertAttention(TFBertSelfAttention):
raise NotImplementedError raise NotImplementedError
def call(self, inputs, training=False): def call(self, inputs, training=False):
input_tensor, attention_mask, head_mask = inputs input_tensor, attention_mask, head_mask, output_attentions = inputs
batch_size = shape_list(input_tensor)[0] batch_size = shape_list(input_tensor)[0]
mixed_query_layer = self.query(input_tensor) mixed_query_layer = self.query(input_tensor)
...@@ -299,7 +301,9 @@ class TFAlbertAttention(TFBertSelfAttention): ...@@ -299,7 +301,9 @@ class TFAlbertAttention(TFBertSelfAttention):
context_layer, (batch_size, -1, self.all_head_size) context_layer, (batch_size, -1, self.all_head_size)
) # (batch_size, seq_len_q, all_head_size) ) # (batch_size, seq_len_q, all_head_size)
self_outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) self_outputs = (
(context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
)
hidden_states = self_outputs[0] hidden_states = self_outputs[0]
...@@ -335,9 +339,11 @@ class TFAlbertLayer(tf.keras.layers.Layer): ...@@ -335,9 +339,11 @@ class TFAlbertLayer(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, inputs, training=False): def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs hidden_states, attention_mask, head_mask, output_attentions = inputs
attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training) attention_outputs = self.attention(
[hidden_states, attention_mask, head_mask, output_attentions], training=training
)
ffn_output = self.ffn(attention_outputs[0]) ffn_output = self.ffn(attention_outputs[0])
ffn_output = self.activation(ffn_output) ffn_output = self.activation(ffn_output)
ffn_output = self.ffn_output(ffn_output) ffn_output = self.ffn_output(ffn_output)
...@@ -354,23 +360,24 @@ class TFAlbertLayerGroup(tf.keras.layers.Layer): ...@@ -354,23 +360,24 @@ class TFAlbertLayerGroup(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.albert_layers = [ self.albert_layers = [
TFAlbertLayer(config, name="albert_layers_._{}".format(i)) for i in range(config.inner_group_num) TFAlbertLayer(config, name="albert_layers_._{}".format(i)) for i in range(config.inner_group_num)
] ]
def call(self, inputs, training=False): def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs hidden_states, attention_mask, head_mask, output_attentions = inputs
layer_hidden_states = () layer_hidden_states = ()
layer_attentions = () layer_attentions = ()
for layer_index, albert_layer in enumerate(self.albert_layers): for layer_index, albert_layer in enumerate(self.albert_layers):
layer_output = albert_layer([hidden_states, attention_mask, head_mask[layer_index]], training=training) layer_output = albert_layer(
[hidden_states, attention_mask, head_mask[layer_index], output_attentions], training=training
)
hidden_states = layer_output[0] hidden_states = layer_output[0]
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
layer_attentions = layer_attentions + (layer_output[1],) layer_attentions = layer_attentions + (layer_output[1],)
if self.output_hidden_states: if self.output_hidden_states:
...@@ -379,7 +386,7 @@ class TFAlbertLayerGroup(tf.keras.layers.Layer): ...@@ -379,7 +386,7 @@ class TFAlbertLayerGroup(tf.keras.layers.Layer):
outputs = (hidden_states,) outputs = (hidden_states,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (layer_hidden_states,) outputs = outputs + (layer_hidden_states,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (layer_attentions,) outputs = outputs + (layer_attentions,)
# last-layer hidden state, (layer hidden states), (layer attentions) # last-layer hidden state, (layer hidden states), (layer attentions)
return outputs return outputs
...@@ -390,7 +397,6 @@ class TFAlbertTransformer(tf.keras.layers.Layer): ...@@ -390,7 +397,6 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
super().__init__(**kwargs) super().__init__(**kwargs)
self.config = config self.config = config
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.embedding_hidden_mapping_in = tf.keras.layers.Dense( self.embedding_hidden_mapping_in = tf.keras.layers.Dense(
config.hidden_size, config.hidden_size,
...@@ -403,7 +409,7 @@ class TFAlbertTransformer(tf.keras.layers.Layer): ...@@ -403,7 +409,7 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
] ]
def call(self, inputs, training=False): def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs hidden_states, attention_mask, head_mask, output_attentions = inputs
hidden_states = self.embedding_hidden_mapping_in(hidden_states) hidden_states = self.embedding_hidden_mapping_in(hidden_states)
all_attentions = () all_attentions = ()
...@@ -423,12 +429,13 @@ class TFAlbertTransformer(tf.keras.layers.Layer): ...@@ -423,12 +429,13 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group], head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
output_attentions,
], ],
training=training, training=training,
) )
hidden_states = layer_group_output[0] hidden_states = layer_group_output[0]
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
all_attentions = all_attentions + layer_group_output[-1] all_attentions = all_attentions + layer_group_output[-1]
if self.output_hidden_states: if self.output_hidden_states:
...@@ -437,7 +444,7 @@ class TFAlbertTransformer(tf.keras.layers.Layer): ...@@ -437,7 +444,7 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
outputs = (hidden_states,) outputs = (hidden_states,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
# last-layer hidden state, (all hidden states), (all attentions) # last-layer hidden state, (all hidden states), (all attentions)
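Note the two accumulation styles used above: inside a layer group one attention tensor is appended per layer with `+ (layer_output[1],)`, while at the transformer level the whole group's tuple is concatenated with `+ layer_group_output[-1]`, so the final all_attentions stays flat with exactly one entry per layer. A tiny pure-Python sketch of that difference, with placeholder strings instead of attention tensors:

group_a = ()
group_a = group_a + ("attn_layer_0",)          # TFAlbertLayerGroup: + (layer_output[1],)
group_b = ("attn_layer_1", "attn_layer_2")     # a second group's already-built tuple

all_attentions = ()
all_attentions = all_attentions + group_a      # TFAlbertTransformer: + layer_group_output[-1]
all_attentions = all_attentions + group_b
print(all_attentions)  # ('attn_layer_0', 'attn_layer_1', 'attn_layer_2') -- one entry per layer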
...@@ -494,6 +501,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer): ...@@ -494,6 +501,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers self.num_hidden_layers = config.num_hidden_layers
self.output_attentions = config.output_attentions
self.embeddings = TFAlbertEmbeddings(config, name="embeddings") self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
self.encoder = TFAlbertTransformer(config, name="encoder") self.encoder = TFAlbertTransformer(config, name="encoder")
...@@ -525,6 +533,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer): ...@@ -525,6 +533,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None,
training=False, training=False,
): ):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
...@@ -534,7 +543,8 @@ class TFAlbertMainLayer(tf.keras.layers.Layer): ...@@ -534,7 +543,8 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
position_ids = inputs[3] if len(inputs) > 3 else position_ids position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs." output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
...@@ -542,10 +552,13 @@ class TFAlbertMainLayer(tf.keras.layers.Layer): ...@@ -542,10 +552,13 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
position_ids = inputs.get("position_ids", position_ids) position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 6, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -588,7 +601,9 @@ class TFAlbertMainLayer(tf.keras.layers.Layer): ...@@ -588,7 +601,9 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
# head_mask = tf.constant([0] * self.num_hidden_layers) # head_mask = tf.constant([0] * self.num_hidden_layers)
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training) embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training) encoder_outputs = self.encoder(
[embedding_output, extended_attention_mask, head_mask, output_attentions], training=training
)
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output[:, 0]) pooled_output = self.pooler(sequence_output[:, 0])
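The main layer above (and TFBertMainLayer below, which mirrors it) resolves the flag in a fixed order: a value in the input list (seventh positional slot) or input dict wins, then the keyword argument, then config.output_attentions. A minimal sketch of that resolution with a placeholder class, not the real TFAlbertMainLayer:

class MainLayerSketch:
    def __init__(self, config_output_attentions=False):
        self.output_attentions = config_output_attentions  # stands in for config.output_attentions

    def call(self, inputs, output_attentions=None):
        if isinstance(inputs, (tuple, list)):
            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
            assert len(inputs) <= 7, "Too many inputs."
        elif isinstance(inputs, dict):
            output_attentions = inputs.get("output_attentions", output_attentions)
        return output_attentions if output_attentions is not None else self.output_attentions

layer = MainLayerSketch(config_output_attentions=False)
print(layer.call({"input_ids": [1, 2, 3]}))                             # False (config default)
print(layer.call({"input_ids": [1, 2, 3]}, output_attentions=True))     # True  (keyword argument)
print(layer.call({"input_ids": [1, 2, 3], "output_attentions": True}))  # True  (dict entry wins)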
...@@ -704,7 +719,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel): ...@@ -704,7 +719,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -756,7 +771,7 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel): ...@@ -756,7 +771,7 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
...@@ -816,7 +831,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel): ...@@ -816,7 +831,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -871,6 +886,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass ...@@ -871,6 +886,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -889,7 +905,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass ...@@ -889,7 +905,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -916,6 +932,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass ...@@ -916,6 +932,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
...@@ -959,6 +976,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat ...@@ -959,6 +976,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -975,7 +993,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat ...@@ -975,7 +993,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -1001,6 +1019,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat ...@@ -1001,6 +1019,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
...@@ -1046,6 +1065,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL ...@@ -1046,6 +1065,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
cls_index=None, cls_index=None,
p_mask=None, p_mask=None,
is_impossible=None, is_impossible=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -1069,7 +1089,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL ...@@ -1069,7 +1089,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -1100,6 +1120,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL ...@@ -1100,6 +1120,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
...@@ -1155,6 +1176,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1155,6 +1176,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -1174,7 +1196,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1174,7 +1196,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -1203,7 +1225,8 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1203,7 +1225,8 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
position_ids = inputs[3] if len(inputs) > 3 else position_ids position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs." output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, dict): elif isinstance(inputs, dict):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
...@@ -1211,7 +1234,8 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1211,7 +1234,8 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
position_ids = inputs.get("position_ids", position_ids) position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 6, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
...@@ -1234,6 +1258,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1234,6 +1258,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
flat_position_ids, flat_position_ids,
head_mask, head_mask,
inputs_embeds, inputs_embeds,
output_attentions,
] ]
outputs = self.albert(flat_inputs, training=training) outputs = self.albert(flat_inputs, training=training)
......
...@@ -29,6 +29,7 @@ from .modeling_tf_utils import ( ...@@ -29,6 +29,7 @@ from .modeling_tf_utils import (
TFQuestionAnsweringLoss, TFQuestionAnsweringLoss,
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
TFTokenClassificationLoss, TFTokenClassificationLoss,
cast_bool_to_primitive,
get_initializer, get_initializer,
keras_serializable, keras_serializable,
shape_list, shape_list,
...@@ -211,7 +212,6 @@ class TFBertSelfAttention(tf.keras.layers.Layer): ...@@ -211,7 +212,6 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
"The hidden size (%d) is not a multiple of the number of attention " "The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads) "heads (%d)" % (config.hidden_size, config.num_attention_heads)
) )
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
assert config.hidden_size % config.num_attention_heads == 0 assert config.hidden_size % config.num_attention_heads == 0
...@@ -235,7 +235,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer): ...@@ -235,7 +235,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
return tf.transpose(x, perm=[0, 2, 1, 3]) return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs, training=False): def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs hidden_states, attention_mask, head_mask, output_attentions = inputs
batch_size = shape_list(hidden_states)[0] batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
...@@ -275,7 +275,10 @@ class TFBertSelfAttention(tf.keras.layers.Layer): ...@@ -275,7 +275,10 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
context_layer, (batch_size, -1, self.all_head_size) context_layer, (batch_size, -1, self.all_head_size)
) # (batch_size, seq_len_q, all_head_size) ) # (batch_size, seq_len_q, all_head_size)
outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) outputs = (
(context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
)
return outputs return outputs
...@@ -307,9 +310,11 @@ class TFBertAttention(tf.keras.layers.Layer): ...@@ -307,9 +310,11 @@ class TFBertAttention(tf.keras.layers.Layer):
raise NotImplementedError raise NotImplementedError
def call(self, inputs, training=False): def call(self, inputs, training=False):
input_tensor, attention_mask, head_mask = inputs input_tensor, attention_mask, head_mask, output_attentions = inputs
self_outputs = self.self_attention([input_tensor, attention_mask, head_mask], training=training) self_outputs = self.self_attention(
[input_tensor, attention_mask, head_mask, output_attentions], training=training
)
attention_output = self.dense_output([self_outputs[0], input_tensor], training=training) attention_output = self.dense_output([self_outputs[0], input_tensor], training=training)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs return outputs
...@@ -358,9 +363,11 @@ class TFBertLayer(tf.keras.layers.Layer): ...@@ -358,9 +363,11 @@ class TFBertLayer(tf.keras.layers.Layer):
self.bert_output = TFBertOutput(config, name="output") self.bert_output = TFBertOutput(config, name="output")
def call(self, inputs, training=False): def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs hidden_states, attention_mask, head_mask, output_attentions = inputs
attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training) attention_outputs = self.attention(
[hidden_states, attention_mask, head_mask, output_attentions], training=training
)
attention_output = attention_outputs[0] attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output) intermediate_output = self.intermediate(attention_output)
layer_output = self.bert_output([intermediate_output, attention_output], training=training) layer_output = self.bert_output([intermediate_output, attention_output], training=training)
...@@ -371,12 +378,11 @@ class TFBertLayer(tf.keras.layers.Layer): ...@@ -371,12 +378,11 @@ class TFBertLayer(tf.keras.layers.Layer):
class TFBertEncoder(tf.keras.layers.Layer): class TFBertEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)] self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
def call(self, inputs, training=False): def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs hidden_states, attention_mask, head_mask, output_attentions = inputs
all_hidden_states = () all_hidden_states = ()
all_attentions = () all_attentions = ()
...@@ -384,10 +390,12 @@ class TFBertEncoder(tf.keras.layers.Layer): ...@@ -384,10 +390,12 @@ class TFBertEncoder(tf.keras.layers.Layer):
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module([hidden_states, attention_mask, head_mask[i]], training=training) layer_outputs = layer_module(
[hidden_states, attention_mask, head_mask[i], output_attentions], training=training
)
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
all_attentions = all_attentions + (layer_outputs[1],) all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer # Add last layer
...@@ -397,7 +405,7 @@ class TFBertEncoder(tf.keras.layers.Layer): ...@@ -397,7 +405,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
outputs = (hidden_states,) outputs = (hidden_states,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
return outputs # outputs, (hidden states), (attentions) return outputs # outputs, (hidden states), (attentions)
...@@ -489,6 +497,7 @@ class TFBertMainLayer(tf.keras.layers.Layer): ...@@ -489,6 +497,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers self.num_hidden_layers = config.num_hidden_layers
self.output_attentions = config.output_attentions
self.embeddings = TFBertEmbeddings(config, name="embeddings") self.embeddings = TFBertEmbeddings(config, name="embeddings")
self.encoder = TFBertEncoder(config, name="encoder") self.encoder = TFBertEncoder(config, name="encoder")
...@@ -515,6 +524,7 @@ class TFBertMainLayer(tf.keras.layers.Layer): ...@@ -515,6 +524,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None,
training=False, training=False,
): ):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
...@@ -524,7 +534,8 @@ class TFBertMainLayer(tf.keras.layers.Layer): ...@@ -524,7 +534,8 @@ class TFBertMainLayer(tf.keras.layers.Layer):
position_ids = inputs[3] if len(inputs) > 3 else position_ids position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs." output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
...@@ -532,10 +543,13 @@ class TFBertMainLayer(tf.keras.layers.Layer): ...@@ -532,10 +543,13 @@ class TFBertMainLayer(tf.keras.layers.Layer):
position_ids = inputs.get("position_ids", position_ids) position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 6, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -578,7 +592,9 @@ class TFBertMainLayer(tf.keras.layers.Layer): ...@@ -578,7 +592,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
# head_mask = tf.constant([0] * self.num_hidden_layers) # head_mask = tf.constant([0] * self.num_hidden_layers)
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training) embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training) encoder_outputs = self.encoder(
[embedding_output, extended_attention_mask, head_mask, output_attentions], training=training
)
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) pooled_output = self.pooler(sequence_output)
...@@ -697,7 +713,7 @@ class TFBertModel(TFBertPreTrainedModel): ...@@ -697,7 +713,7 @@ class TFBertModel(TFBertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -749,7 +765,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel): ...@@ -749,7 +765,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -803,7 +819,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel): ...@@ -803,7 +819,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -853,7 +869,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel): ...@@ -853,7 +869,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -910,6 +926,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific ...@@ -910,6 +926,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -928,7 +945,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific ...@@ -928,7 +945,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -955,6 +972,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific ...@@ -955,6 +972,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
...@@ -1006,6 +1024,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1006,6 +1024,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -1025,7 +1044,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1025,7 +1044,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -1054,7 +1073,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1054,7 +1073,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
position_ids = inputs[3] if len(inputs) > 3 else position_ids position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs." output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
...@@ -1062,7 +1082,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1062,7 +1082,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
position_ids = inputs.get("position_ids", position_ids) position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 6, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
...@@ -1085,6 +1106,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1085,6 +1106,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
flat_position_ids, flat_position_ids,
head_mask, head_mask,
inputs_embeds, inputs_embeds,
output_attentions,
] ]
outputs = self.bert(flat_inputs, training=training) outputs = self.bert(flat_inputs, training=training)
...@@ -1130,6 +1152,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL ...@@ -1130,6 +1152,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -1146,7 +1169,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL ...@@ -1146,7 +1169,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -1172,6 +1195,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL ...@@ -1172,6 +1195,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
...@@ -1218,6 +1242,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss) ...@@ -1218,6 +1242,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
cls_index=None, cls_index=None,
p_mask=None, p_mask=None,
is_impossible=None, is_impossible=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -1241,7 +1266,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss) ...@@ -1241,7 +1266,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -1270,6 +1295,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss) ...@@ -1270,6 +1295,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
......
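On the TensorFlow side the flag is likewise a call-time argument after this change. A hedged usage sketch, assuming this PR's transformers checkout and the public bert-base-uncased weights; with default config flags the last element of the output tuple is the attentions:

import tensorflow as tf
from transformers import BertTokenizer, TFBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer.encode_plus("Hello, world!", return_tensors="tf")
outputs = model(inputs, output_attentions=True)

# a tuple with one (batch_size, num_heads, seq_len, seq_len) tensor per layer
attentions = outputs[-1]
print(len(attentions), attentions[0].shape)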
...@@ -23,7 +23,13 @@ import tensorflow as tf ...@@ -23,7 +23,13 @@ import tensorflow as tf
from .configuration_ctrl import CTRLConfig from .configuration_ctrl import CTRLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list from .modeling_tf_utils import (
TFPreTrainedModel,
TFSharedEmbeddings,
cast_bool_to_primitive,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
...@@ -78,9 +84,8 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N ...@@ -78,9 +84,8 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
class TFMultiHeadAttention(tf.keras.layers.Layer): class TFMultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, d_model_size, num_heads, output_attentions=False, **kwargs): def __init__(self, d_model_size, num_heads, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_attentions = output_attentions
self.num_heads = num_heads self.num_heads = num_heads
self.d_model_size = d_model_size self.d_model_size = d_model_size
...@@ -97,7 +102,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -97,7 +102,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
return tf.transpose(x, perm=[0, 2, 1, 3]) return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs, training=False): def call(self, inputs, training=False):
v, k, q, mask, layer_past, attention_mask, head_mask, use_cache = inputs v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
batch_size = shape_list(q)[0] batch_size = shape_list(q)[0]
q = self.Wq(q) q = self.Wq(q)
...@@ -114,13 +119,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -114,13 +119,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
v = tf.concat((past_value, v), axis=-2) v = tf.concat((past_value, v), axis=-2)
# to cope with keras serialization # to cope with keras serialization
# we need to cast `use_cache` to correct bool use_cache = cast_bool_to_primitive(use_cache, True)
# if it is a tensor
if tf.is_tensor(use_cache):
if hasattr(use_cache, "numpy"):
use_cache = bool(use_cache.numpy())
else:
use_cache = True
if use_cache is True: if use_cache is True:
present = tf.stack((k, v), axis=0) present = tf.stack((k, v), axis=0)
...@@ -134,7 +133,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -134,7 +133,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
output = self.dense(original_size_attention) output = self.dense(original_size_attention)
outputs = (output, present) outputs = (output, present)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (attn,) outputs = outputs + (attn,)
return outputs return outputs
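The inline tensor-to-bool cast removed above is centralised in the `cast_bool_to_primitive` helper imported from `modeling_tf_utils`. A minimal sketch of its behaviour, inferred from the inline code it replaces (the name of the fallback parameter is an assumption), looks like this:

import tensorflow as tf

def cast_bool_to_primitive(bool_variable, default_tensor_to_true=False):
    # Boolean flags can arrive as tensors when the model goes through keras
    # serialization, so read the value back out whenever possible.
    if tf.is_tensor(bool_variable):
        if hasattr(bool_variable, "numpy"):  # eager tensor: value is readable
            return bool(bool_variable.numpy())
        # Symbolic tensor: the value cannot be read, so fall back to the
        # supplied default, e.g. cast_bool_to_primitive(use_cache, True).
        return default_tensor_to_true
    # Plain Python bools (or None) pass through unchanged.
    return bool_variable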
...@@ -147,14 +146,10 @@ def point_wise_feed_forward_network(d_model_size, dff, name=""): ...@@ -147,14 +146,10 @@ def point_wise_feed_forward_network(d_model_size, dff, name=""):
class TFEncoderLayer(tf.keras.layers.Layer): class TFEncoderLayer(tf.keras.layers.Layer):
def __init__( def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, **kwargs):
self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, output_attentions=False, **kwargs
):
super().__init__(**kwargs) super().__init__(**kwargs)
self.multi_head_attention = TFMultiHeadAttention( self.multi_head_attention = TFMultiHeadAttention(d_model_size, num_heads, name="multi_head_attention")
d_model_size, num_heads, output_attentions, name="multi_head_attention"
)
self.ffn = point_wise_feed_forward_network(d_model_size, dff, name="ffn") self.ffn = point_wise_feed_forward_network(d_model_size, dff, name="ffn")
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1") self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
...@@ -164,10 +159,11 @@ class TFEncoderLayer(tf.keras.layers.Layer): ...@@ -164,10 +159,11 @@ class TFEncoderLayer(tf.keras.layers.Layer):
self.dropout2 = tf.keras.layers.Dropout(rate) self.dropout2 = tf.keras.layers.Dropout(rate)
def call(self, inputs, training=False): def call(self, inputs, training=False):
x, mask, layer_past, attention_mask, head_mask, use_cache = inputs x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
normed = self.layernorm1(x) normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention( attn_outputs = self.multi_head_attention(
[normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache], training=training [normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions],
training=training,
) )
attn_output = attn_outputs[0] attn_output = attn_outputs[0]
attn_output = self.dropout1(attn_output, training=training) attn_output = self.dropout1(attn_output, training=training)
...@@ -208,7 +204,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -208,7 +204,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
config.dff, config.dff,
config.resid_pdrop, config.resid_pdrop,
config.layer_norm_epsilon, config.layer_norm_epsilon,
config.output_attentions,
name="h_._{}".format(i), name="h_._{}".format(i),
) )
for i in range(config.n_layer) for i in range(config.n_layer)
...@@ -237,6 +232,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -237,6 +232,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=True, use_cache=True,
output_attentions=None,
training=False, training=False,
): ):
...@@ -249,7 +245,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -249,7 +245,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
head_mask = inputs[5] if len(inputs) > 5 else head_mask head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
use_cache = inputs[7] if len(inputs) > 7 else use_cache use_cache = inputs[7] if len(inputs) > 7 else use_cache
assert len(inputs) <= 8, "Too many inputs." output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
past = inputs.get("past", past) past = inputs.get("past", past)
...@@ -259,10 +256,13 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -259,10 +256,13 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache) use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 8, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 9, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
# If using past key value states, only the last tokens # If using past key value states, only the last tokens
# should be given as an input # should be given as an input
if past is not None: if past is not None:
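The fallback a few lines above (`output_attentions = output_attentions if output_attentions is not None else self.output_attentions`) is the pattern this commit repeats in every TF main layer. A hypothetical one-liner, not part of the diff, just to make the precedence explicit:

def resolve_output_attentions(call_value, config_value):
    # An explicit per-call value wins; None means "use the config default".
    return call_value if call_value is not None else config_value

assert resolve_output_attentions(True, False) is True    # per-call override
assert resolve_output_attentions(None, False) is False   # falls back to config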
...@@ -349,13 +349,16 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -349,13 +349,16 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
for i, (h, layer_past) in enumerate(zip(self.h, past)): for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache], training=training) outputs = h(
[hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
training=training,
)
hidden_states, present = outputs[:2] hidden_states, present = outputs[:2]
if use_cache is True: if use_cache is True:
presents = presents + (present,) presents = presents + (present,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
all_attentions.append(outputs[2]) all_attentions.append(outputs[2])
hidden_states = self.layernorm(hidden_states) hidden_states = self.layernorm(hidden_states)
...@@ -368,7 +371,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -368,7 +371,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
outputs = outputs + (presents,) outputs = outputs + (presents,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
# let the number of heads free (-1) so we can extract attention even after head pruning # let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
...@@ -489,7 +492,7 @@ class TFCTRLModel(TFCTRLPreTrainedModel): ...@@ -489,7 +492,7 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
...@@ -569,7 +572,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel): ...@@ -569,7 +572,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
......
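A hedged usage sketch for the CTRL changes above, showing that the per-call flag composes with the existing `use_cache` output (the `ctrl` checkpoint name is assumed, and the checkpoint is large):

import tensorflow as tf
from transformers import CTRLTokenizer, TFCTRLLMHeadModel

tokenizer = CTRLTokenizer.from_pretrained("ctrl")
model = TFCTRLLMHeadModel.from_pretrained("ctrl")

input_ids = tokenizer.encode("Links Hello world", return_tensors="tf")
# With use_cache left at its default (True), the outputs are
# (lm_logits, presents, attentions) when output_attentions=True.
lm_logits, presents, attentions = model(input_ids, output_attentions=True)
print(len(attentions), attentions[0].shape)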
...@@ -31,6 +31,7 @@ from .modeling_tf_utils import ( ...@@ -31,6 +31,7 @@ from .modeling_tf_utils import (
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
TFSharedEmbeddings, TFSharedEmbeddings,
TFTokenClassificationLoss, TFTokenClassificationLoss,
cast_bool_to_primitive,
get_initializer, get_initializer,
keras_serializable, keras_serializable,
shape_list, shape_list,
...@@ -186,7 +187,6 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer): ...@@ -186,7 +187,6 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
self.n_heads = config.n_heads self.n_heads = config.n_heads
self.dim = config.dim self.dim = config.dim
self.dropout = tf.keras.layers.Dropout(config.attention_dropout) self.dropout = tf.keras.layers.Dropout(config.attention_dropout)
self.output_attentions = config.output_attentions
assert self.dim % self.n_heads == 0 assert self.dim % self.n_heads == 0
...@@ -224,7 +224,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer): ...@@ -224,7 +224,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
context: tf.Tensor(bs, seq_length, dim) context: tf.Tensor(bs, seq_length, dim)
Contextualized layer. Optional: only if `output_attentions=True` Contextualized layer. Optional: only if `output_attentions=True`
""" """
query, key, value, mask, head_mask = inputs query, key, value, mask, head_mask, output_attentions = inputs
bs, q_length, dim = shape_list(query) bs, q_length, dim = shape_list(query)
k_length = shape_list(key)[1] k_length = shape_list(key)[1]
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim) # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
...@@ -263,7 +263,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer): ...@@ -263,7 +263,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
context = unshape(context) # (bs, q_length, dim) context = unshape(context) # (bs, q_length, dim)
context = self.out_lin(context) # (bs, q_length, dim) context = self.out_lin(context) # (bs, q_length, dim)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
return (context, weights) return (context, weights)
else: else:
return (context,) return (context,)
...@@ -303,7 +303,6 @@ class TFTransformerBlock(tf.keras.layers.Layer): ...@@ -303,7 +303,6 @@ class TFTransformerBlock(tf.keras.layers.Layer):
self.hidden_dim = config.hidden_dim self.hidden_dim = config.hidden_dim
self.dropout = tf.keras.layers.Dropout(config.dropout) self.dropout = tf.keras.layers.Dropout(config.dropout)
self.activation = config.activation self.activation = config.activation
self.output_attentions = config.output_attentions
assert config.dim % config.n_heads == 0 assert config.dim % config.n_heads == 0
...@@ -327,11 +326,11 @@ class TFTransformerBlock(tf.keras.layers.Layer): ...@@ -327,11 +326,11 @@ class TFTransformerBlock(tf.keras.layers.Layer):
ffn_output: tf.Tensor(bs, seq_length, dim) ffn_output: tf.Tensor(bs, seq_length, dim)
The output of the transformer block contextualization. The output of the transformer block contextualization.
""" """
x, attn_mask, head_mask = inputs x, attn_mask, head_mask, output_attentions = inputs
# Self-Attention # Self-Attention
sa_output = self.attention([x, x, x, attn_mask, head_mask], training=training) sa_output = self.attention([x, x, x, attn_mask, head_mask, output_attentions], training=training)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
# assert type(sa_output) == tuple # assert type(sa_output) == tuple
...@@ -343,7 +342,7 @@ class TFTransformerBlock(tf.keras.layers.Layer): ...@@ -343,7 +342,7 @@ class TFTransformerBlock(tf.keras.layers.Layer):
ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim) ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim)
output = (ffn_output,) output = (ffn_output,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
output = (sa_weights,) + output output = (sa_weights,) + output
return output return output
...@@ -352,7 +351,6 @@ class TFTransformer(tf.keras.layers.Layer): ...@@ -352,7 +351,6 @@ class TFTransformer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.n_layers = config.n_layers self.n_layers = config.n_layers
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.layer = [TFTransformerBlock(config, name="layer_._{}".format(i)) for i in range(config.n_layers)] self.layer = [TFTransformerBlock(config, name="layer_._{}".format(i)) for i in range(config.n_layers)]
...@@ -377,7 +375,7 @@ class TFTransformer(tf.keras.layers.Layer): ...@@ -377,7 +375,7 @@ class TFTransformer(tf.keras.layers.Layer):
Tuple of length n_layers with the attention weights from each layer Tuple of length n_layers with the attention weights from each layer
Optional: only if output_attentions=True Optional: only if output_attentions=True
""" """
x, attn_mask, head_mask = inputs x, attn_mask, head_mask, output_attentions = inputs
all_hidden_states = () all_hidden_states = ()
all_attentions = () all_attentions = ()
...@@ -387,10 +385,10 @@ class TFTransformer(tf.keras.layers.Layer): ...@@ -387,10 +385,10 @@ class TFTransformer(tf.keras.layers.Layer):
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,) all_hidden_states = all_hidden_states + (hidden_state,)
layer_outputs = layer_module([hidden_state, attn_mask, head_mask[i]], training=training) layer_outputs = layer_module([hidden_state, attn_mask, head_mask[i], output_attentions], training=training)
hidden_state = layer_outputs[-1] hidden_state = layer_outputs[-1]
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
assert len(layer_outputs) == 2 assert len(layer_outputs) == 2
attentions = layer_outputs[0] attentions = layer_outputs[0]
all_attentions = all_attentions + (attentions,) all_attentions = all_attentions + (attentions,)
...@@ -404,7 +402,7 @@ class TFTransformer(tf.keras.layers.Layer): ...@@ -404,7 +402,7 @@ class TFTransformer(tf.keras.layers.Layer):
outputs = (hidden_state,) outputs = (hidden_state,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions) return outputs # last-layer hidden state, (all hidden states), (all attentions)
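The tuple assembled above grows with whichever flags are set for the call. A hypothetical helper, not in the repository, showing how a caller can unpack it without guessing indices:

def split_transformer_outputs(outputs, output_hidden_states, output_attentions):
    # outputs[0] is always the last hidden state; the optional entries are
    # appended in order: all hidden states first, all attentions last.
    hidden_state = outputs[0]
    idx = 1
    all_hidden_states = None
    all_attentions = None
    if output_hidden_states:
        all_hidden_states = outputs[idx]
        idx += 1
    if output_attentions:
        all_attentions = outputs[idx]
    return hidden_state, all_hidden_states, all_attentions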
...@@ -416,6 +414,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer): ...@@ -416,6 +414,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers self.num_hidden_layers = config.num_hidden_layers
self.output_attentions = config.output_attentions
self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings
self.transformer = TFTransformer(config, name="transformer") # Encoder self.transformer = TFTransformer(config, name="transformer") # Encoder
...@@ -429,22 +428,28 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer): ...@@ -429,22 +428,28 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
raise NotImplementedError raise NotImplementedError
def call(self, inputs, attention_mask=None, head_mask=None, inputs_embeds=None, training=False): def call(
self, inputs, attention_mask=None, head_mask=None, inputs_embeds=None, output_attentions=None, training=False
):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
input_ids = inputs[0] input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
head_mask = inputs[2] if len(inputs) > 2 else head_mask head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
assert len(inputs) <= 4, "Too many inputs." output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
assert len(inputs) <= 5, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 4, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 5, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -469,7 +474,9 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer): ...@@ -469,7 +474,9 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
head_mask = [None] * self.num_hidden_layers head_mask = [None] * self.num_hidden_layers
embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim) embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim)
tfmr_output = self.transformer([embedding_output, attention_mask, head_mask], training=training) tfmr_output = self.transformer(
[embedding_output, attention_mask, head_mask, output_attentions], training=training
)
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions) return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
...@@ -566,7 +573,7 @@ class TFDistilBertModel(TFDistilBertPreTrainedModel): ...@@ -566,7 +573,7 @@ class TFDistilBertModel(TFDistilBertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -612,7 +619,6 @@ class TFDistilBertLMHead(tf.keras.layers.Layer): ...@@ -612,7 +619,6 @@ class TFDistilBertLMHead(tf.keras.layers.Layer):
class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel): class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -640,7 +646,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel): ...@@ -640,7 +646,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -694,7 +700,14 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque ...@@ -694,7 +700,14 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def call( def call(
self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, training=False, self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
training=False,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
...@@ -712,7 +725,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque ...@@ -712,7 +725,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -736,6 +749,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque ...@@ -736,6 +749,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
...@@ -772,7 +786,14 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla ...@@ -772,7 +786,14 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def call( def call(
self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, training=False, self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
training=False,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
...@@ -788,7 +809,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla ...@@ -788,7 +809,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -812,6 +833,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla ...@@ -812,6 +833,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
training=training, training=training,
) )
...@@ -861,7 +883,14 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic ...@@ -861,7 +883,14 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def call( def call(
self, inputs, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, training=False, self,
inputs,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
training=False,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
...@@ -880,7 +909,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic ...@@ -880,7 +909,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -979,6 +1008,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn ...@@ -979,6 +1008,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
cls_index=None, cls_index=None,
p_mask=None, p_mask=None,
is_impossible=None, is_impossible=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -1002,7 +1032,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn ...@@ -1002,7 +1032,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
......
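A hedged usage sketch for the DistilBERT changes above. Per the unpacking code in `TFDistilBertMainLayer.call`, the flag can be supplied as a keyword, in the new fifth positional slot, or inside the input dict; `distilbert-base-uncased` is assumed as the checkpoint.

import tensorflow as tf
from transformers import DistilBertTokenizer, TFDistilBertModel

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = TFDistilBertModel.from_pretrained("distilbert-base-uncased")
input_ids = tokenizer.encode("Hello world", return_tensors="tf")

# Keyword form: with only output_attentions set, the model returns
# (last_hidden_state, attentions).
last_hidden_state, attentions = model(input_ids, output_attentions=True)
print(len(attentions), attentions[0].shape)  # n_layers, (batch, heads, seq, seq)

# Dict / BatchEncoding form, handled by the same unpacking branch.
outputs = model({"input_ids": input_ids, "output_attentions": True})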
...@@ -235,6 +235,7 @@ class TFElectraMainLayer(TFElectraPreTrainedModel): ...@@ -235,6 +235,7 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None,
training=False, training=False,
): ):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
...@@ -244,7 +245,8 @@ class TFElectraMainLayer(TFElectraPreTrainedModel): ...@@ -244,7 +245,8 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
position_ids = inputs[3] if len(inputs) > 3 else position_ids position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs." output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
...@@ -252,10 +254,13 @@ class TFElectraMainLayer(TFElectraPreTrainedModel): ...@@ -252,10 +254,13 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
position_ids = inputs.get("position_ids", position_ids) position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 6, "Too many inputs." output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -278,7 +283,9 @@ class TFElectraMainLayer(TFElectraPreTrainedModel): ...@@ -278,7 +283,9 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
if hasattr(self, "embeddings_project"): if hasattr(self, "embeddings_project"):
hidden_states = self.embeddings_project(hidden_states, training=training) hidden_states = self.embeddings_project(hidden_states, training=training)
hidden_states = self.encoder([hidden_states, extended_attention_mask, head_mask], training=training) hidden_states = self.encoder(
[hidden_states, extended_attention_mask, head_mask, output_attentions], training=training
)
return hidden_states return hidden_states
...@@ -372,7 +379,7 @@ class TFElectraModel(TFElectraPreTrainedModel): ...@@ -372,7 +379,7 @@ class TFElectraModel(TFElectraPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -421,6 +428,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel): ...@@ -421,6 +428,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -433,7 +441,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel): ...@@ -433,7 +441,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -452,7 +460,14 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel): ...@@ -452,7 +460,14 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
""" """
discriminator_hidden_states = self.electra( discriminator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, training=training input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
training=training,
) )
discriminator_sequence_output = discriminator_hidden_states[0] discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.discriminator_predictions(discriminator_sequence_output) logits = self.discriminator_predictions(discriminator_sequence_output)
...@@ -514,6 +529,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel): ...@@ -514,6 +529,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel):
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -526,7 +542,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel): ...@@ -526,7 +542,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -546,7 +562,14 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel): ...@@ -546,7 +562,14 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel):
""" """
generator_hidden_states = self.electra( generator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, training=training input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions=output_attentions,
training=training,
) )
generator_sequence_output = generator_hidden_states[0] generator_sequence_output = generator_hidden_states[0]
prediction_scores = self.generator_predictions(generator_sequence_output, training=training) prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
...@@ -584,6 +607,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific ...@@ -584,6 +607,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -600,7 +624,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific ...@@ -600,7 +624,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -621,7 +645,14 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific ...@@ -621,7 +645,14 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
""" """
discriminator_hidden_states = self.electra( discriminator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, training=training input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
training=training,
) )
discriminator_sequence_output = discriminator_hidden_states[0] discriminator_sequence_output = discriminator_hidden_states[0]
discriminator_sequence_output = self.dropout(discriminator_sequence_output) discriminator_sequence_output = self.dropout(discriminator_sequence_output)
...@@ -665,6 +696,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin ...@@ -665,6 +696,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
cls_index=None, cls_index=None,
p_mask=None, p_mask=None,
is_impossible=None, is_impossible=None,
output_attentions=None,
training=False, training=False,
): ):
r""" r"""
...@@ -688,7 +720,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin ...@@ -688,7 +720,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
of shape :obj:`(batch_size, sequence_length, hidden_size)`. of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`: :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
...@@ -711,7 +743,14 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin ...@@ -711,7 +743,14 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
""" """
discriminator_hidden_states = self.electra( discriminator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, training=training input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
training=training,
) )
discriminator_sequence_output = discriminator_hidden_states[0] discriminator_sequence_output = discriminator_hidden_states[0]
......
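The commit message mentions chasing `test_output_attentions` regressions across the TF models; a simplified, hypothetical check in that spirit (not the repository's actual test), using a tiny randomly initialised ELECTRA so nothing is downloaded:

import tensorflow as tf
from transformers import ElectraConfig, TFElectraModel

# Small config so the check runs without pretrained weights.
config = ElectraConfig(
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=37,
    max_position_embeddings=64,
)
model = TFElectraModel(config)
input_ids = tf.constant([[1, 2, 3, 4, 5]])

outputs = model(input_ids, output_attentions=True)
attentions = outputs[-1]

assert len(attentions) == config.num_hidden_layers
assert attentions[0].shape == (1, config.num_attention_heads, 5, 5)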